diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..0346614583921620cf9c06433641c8b685d936aa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py @@ -0,0 +1,299 @@ +""" +This module implements CUDA graphs support for TorchDynamo backends. + +CUDA graphs allow for capturing and replaying GPU operations, which can significantly +reduce CPU overhead in GPU-accelerated PyTorch models. This module provides: + +- CUDA graph creation and management for both forward and backward passes +- Input mutation detection and handling +- Device compatibility checking +- Stack trace management for debugging +- Integration with TorchInductor's cudagraph trees + +The backend supports two main modes: +1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization +2. 
def find_input_mutations(g: torch.fx.Graph) -> set[int]:
    """Return the positional indices of graph inputs that the graph writes to.

    Placeholders are numbered in graph order. Every ``call_function`` node whose
    target exposes a ``_schema`` is inspected: each argument the schema marks as
    written (``alias_info.is_write``) is resolved to the tensor stored in its
    node metadata, and every placeholder sharing that tensor's storage is
    reported as mutated.

    Written arguments given as lists or tuples are traversed element by element.
    Arguments that are not nodes, or whose metadata is not a tensor, contribute
    nothing.

    Args:
        g: Graph whose nodes carry ``meta["val"]`` or ``meta["fake_result"]``.

    Returns:
        The set of placeholder indices whose storage is written in place.

    Raises:
        KeyError: If an inspected node has neither ``"val"`` nor
            ``"fake_result"`` in its ``meta``.
    """

    def meta_fk(meta: dict[str, Any]) -> Any:
        return meta["val"] if "val" in meta else meta["fake_result"]

    def written_tensors(argument: Any) -> list[torch.Tensor]:
        # Flatten one schema argument into the tensors recorded on its nodes.
        if isinstance(argument, (list, tuple)):
            return [t for item in argument for t in written_tensors(item)]
        if isinstance(argument, torch.fx.Node):
            value = meta_fk(argument.meta)
            if isinstance(value, torch.Tensor):
                return [value]
        return []

    # Maps the storage of each tensor placeholder to the input indices using it,
    # so that inputs aliasing one storage are all reported together.
    inputs: defaultdict[StorageWeakRef, set[int]] = defaultdict(set)
    mutated_inputs: set[int] = set()
    input_idx = 0
    for n in g.nodes:
        if n.op == "placeholder":
            value = meta_fk(n.meta)
            if isinstance(value, torch.Tensor):
                inputs[StorageWeakRef(value._typed_storage())].add(input_idx)
            # Non-tensor placeholders still occupy an input position.
            input_idx += 1
        elif n.op == "call_function":
            schema = getattr(n.target, "_schema", None)
            if schema is None:
                continue

            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                elif arg.name in n.kwargs:
                    argument = n.kwargs[arg.name]
                else:
                    # Argument was left at its schema default.
                    continue
                if not (arg.alias_info and arg.alias_info.is_write):
                    continue
                for tensor in written_tensors(argument):
                    key = StorageWeakRef(tensor._typed_storage())
                    mutated_inputs.update(inputs.get(key, ()))

    # TODO: error on unrecognized nodes
    return mutated_inputs
dynamo_inputs: Sequence[Any]) -> Any: + from torch._inductor.cudagraph_trees import cudagraphify_impl + + do_cudagraphs = BoxedBool(True) + boxed_device_index = BoxedDeviceIndex(None) + + def forward_cudagraphs( + aot_model: torch.fx.GraphModule, + aot_inputs: list[Any], + is_inference: bool = False, + ) -> Any: + interp = boxed_nop(aot_model, aot_inputs) + fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs)) + if skip_msg := check_for_skip(aot_model, fixed): + BoxedBool.disable(do_cudagraphs) + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {skip_msg}" + ) + return interp + + boxed_device_index.set(get_device_index(aot_model)) + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=boxed_device_index.value, + is_backward=False, + is_inference=False, # Q: should forward is_inference here? + stack_traces=get_stack_traces(aot_model), + placeholders=get_placeholder_info(aot_model.graph), + mutated_input_idxs=find_input_mutations(aot_model.graph), + ) + out._boxed_call = True # type: ignore[attr-defined] + return out + + def backward_cudagraphs( + aot_model: torch.fx.GraphModule, aot_inputs: list[Any] + ) -> Any: + interp = boxed_nop(aot_model, aot_inputs) + if not do_cudagraphs: + return aot_model + + fixed = count_tangents(aot_model) + if skip_msg := check_for_skip(aot_model, fixed): + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {skip_msg}" + ) + + # See [Backward Generation Handling] + device_idx = boxed_device_index.value + if device_idx is None: + device_idx = 0 # Default to device 0 if not set + manager = torch._inductor.cudagraph_trees.get_manager( + device_idx, create_if_none_exists=False + ) + assert manager is not None + + def fn(inputs: list[Any]) -> Any: + # pyrefly: ignore [missing-attribute] + manager.set_to_running_backward() + return aot_model(inputs) + + fn._boxed_call = True # type: ignore[attr-defined] + return fn + + out = cudagraphify_impl( + interp, + aot_inputs, 
class CudagraphsBackend:
    """Callable object that forwards compilation requests to ``cudagraphs``."""

    compiler_name = "cudagraphs"

    @staticmethod
    def reset() -> None:
        """Call ``reset_cudagraph_trees`` from ``torch._inductor.cudagraph_trees``."""
        # Function-scope import: the inductor module is only imported when
        # reset() actually runs.
        from torch._inductor import cudagraph_trees

        cudagraph_trees.reset_cudagraph_trees()

    @staticmethod
    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
        """Return the result of ``cudagraphs(model, inputs)``."""
        compiled = cudagraphs(model, inputs)
        return compiled
def cudagraphs_inner(
    model: Callable[..., Any],
    inputs: Sequence[Any],
    copy_outputs: bool = True,
    copy_inputs: bool = True,
) -> Callable[..., Sequence[Any]]:
    """Capture one call of ``model`` in a CUDA graph and return a replay function.

    This isn't registered as a backend, but is used in some benchmarks.

    Args:
        model: Callable invoked once for warmup and once under graph capture.
        inputs: List or tuple of example inputs.
        copy_outputs: When true, the replay function returns clones of the
            captured outputs; otherwise it returns the captured outputs
            themselves.
        copy_inputs: When true, capture runs on ``zeros_like`` copies of
            ``inputs`` and each replay first copies the new inputs into them;
            otherwise capture runs on ``inputs`` directly and replay ignores
            the values it is given.

    Returns:
        A function taking as many positional inputs as ``inputs`` that replays
        the graph and returns a sequence of outputs.
    """
    assert isinstance(inputs, (list, tuple))
    placeholders = (
        [torch.zeros_like(tensor) for tensor in inputs]
        if copy_inputs
        else list(inputs)
    )

    # warmup: run the model once on a side stream, fenced by synchronizations
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(*inputs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # record: capture a second call, this time on the placeholder inputs
    captured = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured, stream=side_stream):
        recorded = model(*placeholders)
    if not isinstance(recorded, (list, tuple)):
        recorded = (recorded,)

    def run(*new_inputs: Any) -> Sequence[Any]:
        assert len(placeholders) == len(new_inputs)
        if copy_inputs:
            for slot, fresh in zip(placeholders, new_inputs):
                slot.copy_(fresh)
        captured.replay()
        if not copy_outputs:
            return recorded
        return [out.clone() for out in recorded]

    return run
+ +It provides functionality to optimize models wrapped in DistributedDataParallel (DDP) +by intelligently splitting compiled graphs to align with DDP's gradient synchronization +boundaries. Key features include: + +- Graph partitioning based on parameter bucket sizes +- Optimization of allreduce operations for distributed training +- Support for parameter ignoring and buffer handling +- Submodule compilation and management +- Debugging utilities for distributed training + +The main component is the DDPOptimizer class, which handles graph splitting and +recompilation to enable efficient distributed training while maintaining the benefits +of compilation. +""" + +import logging +import traceback +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Optional, TYPE_CHECKING +from unittest import mock + +import torch +from torch import fx +from torch._dynamo.backends.registry import CompiledFn, CompilerFn +from torch._dynamo.output_graph import GraphCompileReason +from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.node import Node + + +if TYPE_CHECKING: + from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta + + +# Regular log messages should go through 'log'. +# ddp_graph_log is a separate artifact logger reserved for dumping graphs. +# See docs/source/logging.rst for more info. 
def args_str(args: Any) -> str:
    """Render ``args`` as a compact string for debug logging.

    Tensors become ``T[<shape>]``; tuples and lists are rendered recursively as
    ``tuple(...)`` / ``list(...)``; anything else falls back to ``str``.
    """
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    # Tuple is checked before list, mirroring the order of the original chain.
    for container, label in ((tuple, "tuple"), (list, "list")):
        if isinstance(args, container):
            rendered = ", ".join(args_str(item) for item in args)
            return f"{label}({rendered})"
    return str(args)
def has_higher_order_op(gm: fx.GraphModule) -> bool:
    """Return True if any ``get_attr`` node of ``gm`` resolves to a GraphModule.

    Each ``get_attr`` target is looked up with ``getattr`` on ``gm`` itself; a
    target that is an ``fx.GraphModule`` is treated as evidence of a higher
    order op in the graph.
    """
    return any(
        isinstance(getattr(gm, node.target), torch.fx.GraphModule)
        for node in gm.graph.nodes
        if node.op == "get_attr"
    )
not in name and len(name): + # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384 + module.meta = orig_gm.meta + module._param_name_to_source = orig_gm._param_name_to_source + + +def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None: + name_to_dynamo_source = {} + for node in orig_gm.graph.find_nodes(op="placeholder"): + name_to_dynamo_source[node.name] = node._dynamo_source + + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + for node in module.graph.find_nodes(op="placeholder"): + # non-placeholder in original_gm may become placeholder in submodules + node._dynamo_source = name_to_dynamo_source.get(node.name, None) + + +class DDPOptimizerContext: + def __init__(self) -> None: + self.curr_bucket: int = -1 + self.metadata_per_bucket: list[ViewAndMutationMeta] = [] + + +# compile each of the partitioned submodules using the user-provided compiler +class SubmodCompiler(torch.fx.interpreter.Interpreter): + def __init__( + self, + module: fx.GraphModule, + compiler: CompilerFn, + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, + ) -> None: + super().__init__(module) + self.compiler = compiler + self.fake_mode = fake_mode + # See Note [DDPOptimizer and fw_metadata] + ctx = torch._guards.TracingContext.try_get() + if ctx is not None: + ctx.ddp_optimizer_ctx = DDPOptimizerContext() + + def compile_submod( + self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any + ) -> Any: + """ + Compile the submodule, + using a wrapper to make sure its output is always a tuple, + which is required by AotAutograd based compilers + """ + assert len(kwargs) == 0, "We assume only args for these modules" + + class WrapperModule(torch.nn.Module): + def __init__( + self, submod: Callable[..., Any], unwrap_singleton_tuple: bool + ) -> None: + super().__init__() + self.submod = submod + self.unwrap_singleton_tuple = unwrap_singleton_tuple + + def forward(self, 
*args: Any) -> Any: + x = self.submod(*args) + # TODO(whc) + # for some reason the isinstance check is necessary if I split one node per submod + # - even though I supposedly wrapped the output in a tuple in those cases, the real + # compiled module was still returning a tensor + if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)): + return x[0] + return x + + unwrap_singleton_tuple = False + for sn in input_mod.graph.nodes: + if sn.op == "output": + if not isinstance(sn.args[0], tuple): + unwrap_singleton_tuple = True + sn.args = (sn.args,) + + input_mod.recompile() + input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment] + "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])." + " Set `torch._dynamo.config.optimize_ddp = False` to disable.", + [ + # it's close to useless to get a real stacktrace here, and quite verbose. + traceback.FrameSummary(__file__, 0, "DDPOptimizer"), + ], + ) + + wrapper = WrapperModule( + self.compiler(input_mod, args), + unwrap_singleton_tuple, + ) + return wrapper + + # Note: + # + # The way distributed works today around fake tensors can be somewhat confusing. + # Some of these codepaths are shared in both runtime, and compile time. The presence + # of a fake_mode, read off of fake tensor inputs, dictates how we will operate. + # + # A few things to keep in mind: + # + # 1) We invoke `compile_submod` with a real module. The output of that gets stored + # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`. + # + # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the + # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it. + # + # 3) Fake tensors should always be around during compile time. + # + # 4) Fake tensors should never be around at runtime. + # + # 5) We end up with a compilation mode that takes a real submodule and fake tensors, + # to match what aot_autograd expects. 
See Note: [Fake Modules and AOTAutograd] + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + new_args = [] + assert self.fake_mode + for arg in args: + if isinstance(arg, torch.Tensor) and not isinstance( + arg, torch._subclasses.FakeTensor + ): + new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode)) + else: + new_args.append(arg) + + log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args)) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + if n.op == "call_module": + real_mod = self.fetch_attr(str(n.target)) + if self.fake_mode: + curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode) + else: + curr_submod = real_mod + + ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph) + + # When calling the compiler on the submod, inputs (new_args) are expected to + # be FakeTensors already since Dynamo would have made them FakeTensors in the + # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors, + # since this wrapping happens during compilation + + # Note: Returning Fake Tensors on First AOT Autograd Call + # + # Inductor will optimize strides of outputs when it deems it profitable. + # For instance, converting to channels last. When we split the graph here + # into multiple inductor compilations, we need to make sure that the + # output strides of one compilation is appropriately passed to the subsequent + # compilations. However, the mapping from inductor output to dynamo output + # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing, + # subclass handling, etc. In order to replay all this logic we set a flag such that + # the first invocation of inductor in aot_autograd will return Fake Tensors with + # appropriate strides. Then, all of aot autograd's runtime logic is replayed. + # This gives us the appropriately strided outputs here which will reflect runtime strides. 
+ + class FakeifyFirstAOTInvocationGuard: + def __init__(self) -> None: + self.tc = torch._guards.TracingContext.try_get() + assert self.tc + self.tc.fakify_first_call = True + + def __del__(self) -> None: + self.tc.fakify_first_call = False # type: ignore[union-attr] + + # For aot_eager and other backends, tracing context is not set + has_tracing_context = torch._guards.TracingContext.try_get() is not None + if has_tracing_context: + g = FakeifyFirstAOTInvocationGuard() # noqa: F841 + + from torch._dynamo.utils import counters + + init = counters["aot_autograd"]["total"] + compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs) + + # TODO - better way of doing this? + # Only aot autograd handles fakifying first call + invoked_aot_autograd = init != counters["aot_autograd"]["total"] + + # We update the original (outer) graph with a call into the compiled module + # instead of the uncompiled one. + self.module.delete_submodule(n.target) # type: ignore[operator] + n.target = "compiled_" + n.target # type: ignore[operator] + self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator] + + # Finally, we have to produce inputs for use compiling the next submodule, + # and these need to be FakeTensors, so we execute the module under fake_mode + # Because parameters are not fake we patch fake tensor mode to allow non fake inputs + with ( + self.fake_mode, + mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True), + ): + if has_tracing_context and invoked_aot_autograd: + tracing_ctx = torch._guards.TracingContext.try_get() + assert tracing_ctx is not None + # DDPOptimizer maintains 1 dynamo graph -> N AOT graphs + # Dynamo only has 1 tracing context, so it needs to maintain all N AOT metadata instances + ddp_ctx = tracing_ctx.ddp_optimizer_ctx + assert ddp_ctx is not None + assert tracing_ctx.fw_metadata is not None + ddp_ctx.curr_bucket += 1 + ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata) + + out = 
compiled_submod_real(*new_args, **kwargs) + # output should be fake or subclass + assert all( + (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor) + for t in (out if isinstance(out, (list, tuple)) else [out]) + ) + return out + else: + return curr_submod(*new_args, **kwargs) + else: + # placeholder or output nodes don't need to get compiled, just executed + return getattr(self, n.op)(n.target, new_args, kwargs) + + +class DDPOptimizer: + """Note [DDPOptimizer] + DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP), + breaking the dynamo graph into chunks to compile separately, with the breaks aligning to + the boundaries of gradient-allreduce buckets chosen by DDP. + + Background/Motivation + - DDP uses allreduce collectives to synchronize partial gradients computed on different workers + - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce + - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready + at around the same time during backward and thus can share the same allreduce efficiently + - Allreduces must overlap with backward compute for optimal training performance + - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which + operates when individual grads become 'ready' + - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the + autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole + fused backward function executes, preventing any overlap of compute and communication + + Algorithm + - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse + this graph in reverse order to determine the true order that gradients will become ready during backward. 
+ - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started + and a graph break introduced + - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together + into an outer module that is returned to the user + + Notes + - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP, + and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does + in eager. + - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently + produce splits that do not necessarily align with the buckets used by DDP. This should result in performance + degradation approaching the baseline case where graph-splits are not used, but not worse. + - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the + subgraphs being compiled + - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers + left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are + also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP, + it is not catastrophic but could impact performance by choosing sub-optimal bucket splits. + - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients, + and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by + DDPOptimizer) + + Debugging + - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb. + - In many cases, the log messages are helpful (they show bucket size assignments)- + just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'. 
+ - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model + in a single process (or with torchrun, in multiple processes) + + Args: + bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be + set to match the equivalent parameter on the original DDP module. + + backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph. + + first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP + special-cases the first bucket size since it is sometimes optimal to start a small allreduce early. + + """ + + def __init__( + self, + bucket_bytes_cap: int, + backend_compile_fn: CompilerFn, + first_bucket_cap: Optional[int] = None, + ) -> None: + if first_bucket_cap is not None: + self.first_bucket_cap = first_bucket_cap + elif torch.distributed.is_available(): + # this constant comes from C10D lib which is not always built + self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES + else: + self.first_bucket_cap = bucket_bytes_cap + + self.bucket_bytes_cap = bucket_bytes_cap + assert self.first_bucket_cap <= self.bucket_bytes_cap, ( + "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP" + ) + + self.backend_compile_fn = backend_compile_fn + + def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool: + return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored + + def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None: + bucket.size += param.untyped_storage().nbytes() + bucket.params.append(name) + bucket.param_ids.append(id(param)) + + def add_module_params_to_bucket( + self, + mod: torch.nn.Module, + bucket: Bucket, + processed_modules: set[torch.nn.Module], + prefix: str, + ) -> None: + processed_modules.add(mod) + for name, param in mod.named_parameters(): + if param.requires_grad and not 
self._ignore_parameter(param): + self.add_param(bucket, param, f"{prefix}_{name}") + + def add_param_args(self, bucket: Bucket, node: fx.Node) -> None: + for arg in node.args: + if not isinstance(arg, torch.fx.node.Node): + continue + if arg.op != "placeholder": + continue + param = arg.meta["example_value"] + if ( + isinstance(param, torch.nn.Parameter) + and param.requires_grad + and not self._ignore_parameter(param) + ): + self.add_param(bucket, param, str(arg.target)) + + def compile_fn( + self, gm: fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> CompiledFn: + """ + Implements graph splitting, first determining a set of of buckets by counting + parameter sizes in reverse graph order, then invoking the user/backend compiler + to compile each subgraph. Finally, stiches compiled graphs into one graphmodule + and returns its callable. + """ + # 1: compute the partition map according to DDP bucket logic + buckets = [Bucket()] # (size, param_names) + processed_modules: set[torch.nn.Module] = set() + for node in reversed(gm.graph.nodes): + if node.op in ("output", "placeholder"): + continue + + if ( + buckets[0].size >= self.bucket_bytes_cap + or len(buckets) == 1 + and buckets[0].size >= self.first_bucket_cap + ): + if bucket_has_external_output(buckets[0]): + buckets.insert(0, Bucket()) + else: + # continue building this bucket past the point of filling its parameter capacity, + # to increase chances it contains at least one node that is either a global output or + # passed as input to a subsequent graph + + if buckets[0].opcount_increased_to_capture_external_output == 0: + buckets[0].paramsize_before_opcount_increase = buckets[0].size + buckets[0].opcount_increased_to_capture_external_output += 1 + + if node.op == "call_function": + self.add_param_args(buckets[0], node) + + elif node.op == "call_module": + target_mod = gm.get_submodule(node.target) + if target_mod not in processed_modules: + self.add_module_params_to_bucket( + target_mod, buckets[0], 
processed_modules, node.target + ) + elif node.op == "call_method": + if isinstance(node.args[0].target, str): + target_mod = None + try: + target_mod = gm.get_submodule(node.args[0].target) + except AttributeError: + pass + if target_mod is not None and target_mod not in processed_modules: + self.add_module_params_to_bucket( + target_mod, buckets[0], processed_modules, node.target + ) + # This handles situations like tmp = torch.mm(x, self.weight.t()) + # t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None + # tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None + self.add_param_args(buckets[0], node) + + elif node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if ( + isinstance(maybe_param, torch.nn.Parameter) + and maybe_param.requires_grad + and not self._ignore_parameter(maybe_param) + ): + self.add_param(buckets[0], maybe_param, node.target) + + # All nodes have to be mapped to a bucket, even if they don't have their own params + # Ignored params still end up in buckets, we just don't count them towards the capacity + buckets[0].nodes.append(node) + + if len(buckets) > 1 and buckets[0].size == 0: + # we collected a small preamble graph with ops that don't include parameters, fuse it back + buckets[1].nodes.extend(buckets[0].nodes) + assert len(buckets[0].params) == 0, "Params should be empty if size is 0" + del buckets[0] + + # stash buckets for testing/debugging purposes + self.buckets = buckets + pretty_print_buckets(buckets, self.bucket_bytes_cap) + + if len(buckets) == 1: + # bypass split/fuse logic if there is only one bucket + return self.backend_compile_fn(gm, example_inputs) + + # 2: partition the graphmodule according to bucket capacity + partition_map = {} + for idx, b in enumerate(buckets): + for node in b.nodes: + partition_map[node] = idx + + split_gm = fx.passes.split_module.split_module( + gm, + None, # type: ignore[arg-type] + lambda node: partition_map[node], + ) + + # See note [Assumption on 
Dynamo Metadata] + propagate_dynamo_source(gm, split_gm) + propagate_metadata(gm, split_gm) + + debug_str = ( + f"\n---orig graph---\n{gm.graph}\n" + + f"\n---split graph---\n{split_gm.graph}\n" + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # only print the submod graphs, not their children + debug_str += f"\n---{name} graph---\n{module.graph}\n" + debug_str += "\n---------------\n" + ddp_graph_log.debug(debug_str) + + trace_structured( + "optimize_ddp_split_graph", + payload_fn=lambda: split_gm.print_readable(print_output=False), + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + trace_structured( + "optimize_ddp_split_child", + lambda: {"name": name}, + payload_fn=lambda: module.print_readable(print_output=False), + ) + + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + + submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode) + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + submod_compiler.run(*example_inputs) + split_gm.recompile() + + ddp_graph_log.debug( + "\n---final graph---\n%s\n---------------\n", split_gm.graph + ) + return split_gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py new file mode 100644 index 0000000000000000000000000000000000000000..93490e64f4ae2044d0c641f8171e733ed7a8e141 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py @@ -0,0 +1,39 @@ +# This backend is maintained by ONNX team. To direct issues +# to the right people, please tag related GitHub issues with `module: onnx`. 
+# +# Maintainers' Github IDs: wschin, xadupre +# from torch.onnx._internal.onnxruntime import ( +# is_onnxrt_backend_supported, +# torch_compile_backend, +# ) + +# from .registry import register_backend + +""" +Placeholder for onnxruntime backend for dynamo +""" + +# def has_onnxruntime(): +# # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() +# return is_onnxrt_backend_supported() + + +# if is_onnxrt_backend_supported(): +# register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +# else: + +# def information_displaying_backend(*args, **kwargs): +# raise ImportError( +# "onnxrt is not registered as a backend. " +# "Please make sure all dependencies such as " +# "numpy, onnx, onnxscript, and onnxruntime-training are installed. " +# "Suggested procedure to fix dependency problem:\n" +# " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" +# " (2) Open a new python terminal.\n" +# " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" +# " (4) If it returns `True`, then you can use `onnxrt` backend.\n" +# " (5) If it returns `False`, please execute the package importing section in " +# "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." +# ) + +# register_backend(name="onnxrt", compiler_fn=information_displaying_backend) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1469ca478a38647f91b95f1eed8b2a0e6408dd66 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py @@ -0,0 +1,179 @@ +""" +This module implements TorchDynamo's backend registry system for managing compiler backends. 
+ +The registry provides a centralized way to register, discover and manage different compiler +backends that can be used with torch.compile(). It handles: + +- Backend registration and discovery through decorators and entry points +- Lazy loading of backend implementations +- Lookup and validation of backend names +- Categorization of backends using tags (debug, experimental, etc.) + +Key components: +- CompilerFn: Type for backend compiler functions that transform FX graphs +- _BACKENDS: Registry mapping backend names to entry points +- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions + +Example usage: + @register_backend + def my_compiler(fx_graph, example_inputs): + # Transform FX graph into optimized implementation + return compiled_fn + + # Use registered backend + torch.compile(model, backend="my_compiler") + +The registry also supports discovering backends through setuptools entry points +in the "torch_dynamo_backends" group. Example: +``` +setup.py +--- +from setuptools import setup + +setup( + name='my_torch_backend', + version='0.1', + packages=['my_torch_backend'], + entry_points={ + 'torch_dynamo_backends': [ + # name = path to entry point of backend implementation + 'my_compiler = my_torch_backend.compiler:my_compiler_function', + ], + }, +) +``` +``` +my_torch_backend/compiler.py +--- +def my_compiler_function(fx_graph, example_inputs): + # Transform FX graph into optimized implementation + return compiled_fn +``` +Using `my_compiler` backend: +``` +import torch + +model = ... # Your PyTorch model +optimized_model = torch.compile(model, backend="my_compiler") +``` +""" + +import functools +import logging +from collections.abc import Callable, Sequence +from importlib.metadata import EntryPoint +from typing import Any, Optional, Protocol, Union + +import torch +from torch import fx + + +log = logging.getLogger(__name__) + + +class CompiledFn(Protocol): + def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ... 
def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
) -> Callable[..., Any]:
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand.  Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Works both bare (``@register_backend``) and parameterized
    (``@register_backend(name=..., tags=...)``); in the parameterized form
    a partial of this function is returned and performs the registration
    once it is applied to the compiler function.

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with

    Returns:
        ``compiler_fn`` itself, with the tags stored on it as the tuple
        attribute ``_tags``; or a partial when ``compiler_fn`` is omitted.

    Raises:
        AssertionError: if ``compiler_fn`` is not callable or ``name`` is
            already present in ``_COMPILER_FNS``.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)  # type: ignore[return-value]
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _COMPILER_FNS, f"duplicate name: {name}"
    # _BACKENDS is keyed by backend name. The membership test must therefore
    # use `name`; testing the callable was always true and replaced an
    # already-discovered entry point with None.
    if name not in _BACKENDS:
        _BACKENDS[name] = None
    _COMPILER_FNS[name] = compiler_fn
    compiler_fn._tags = tuple(tags)  # type: ignore[attr-defined]
    return compiler_fn
# NOTE: can't type this due to public api mismatch; follow up with dev team
def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:  # type: ignore[no-untyped-def]
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")

    A backend is left out only when its compiler function is already loaded
    and carries at least one tag from ``exclude_tags``. Names that are known
    but not yet loaded are always listed. A falsy ``exclude_tags`` disables
    filtering. The returned names are sorted.
    """
    _lazy_import()
    hidden_tags = frozenset(exclude_tags) if exclude_tags else frozenset()

    visible = []
    for backend_name in _BACKENDS:
        if backend_name in _COMPILER_FNS:
            fn_tags = _COMPILER_FNS[backend_name]._tags  # type: ignore[attr-defined]
            if not hidden_tags.isdisjoint(fn_tags):
                continue
        visible.append(backend_name)
    visible.sort()
    return visible
def xla_backend_helper(
    model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False
) -> Callable[..., Any]:
    """
    Build a callable that compiles ``model`` through the torch_xla dynamo
    bridge on its first invocation and reuses the compiled graph afterwards.

    ``fake_tensor_inputs`` is accepted for the backend signature but is not
    read here; compilation uses the arguments of the first real call.

    Args:
        model: FX graph module to compile.
        fake_tensor_inputs: Unused.
        boxed: When true, the callable is wrapped with ``make_boxed_func``.

    Raises:
        ImportError: if ``torch_xla.core.dynamo_bridge`` cannot be imported.
    """
    try:
        import torch_xla.core.dynamo_bridge as bridge
    except ImportError as e:
        raise ImportError(
            "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
        ) from e

    compiled_graph = None

    def fwd(*args: torch.Tensor) -> Any:
        nonlocal model, compiled_graph
        if compiled_graph is not None:
            return compiled_graph(*args)
        compiled_graph = bridge.extract_compiled_graph(model, args)
        # Drop the closure's reference to the source module once compiled.
        del model
        return compiled_graph(*args)

    if boxed:
        return make_boxed_func(fwd)
    return fwd
0000000000000000000000000000000000000000..a320103c6e2878e35ca06305794c4b1939fde08c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4594034f9a8c2ef547c1576a349580c661afe847 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4811a0ba0d25f3ebfeeca051b2da1ecc6c0ff71a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6f606e39e9549586f4151d1cde170d65d4623a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d6345304ce17037baeac20cb2950a8e4d0d52eec Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f14dca4803f798e4258fe08607c27f584970bdd2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e95fbab51525c025ffd84139f8c51c58b6e9a217 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79604ac27272fa42527f30c42b6f9baa01f3330f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..40ba412b9b921537ec0d65e971f844e7eae2475a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py new file mode 100644 index 0000000000000000000000000000000000000000..048a71cd6c16a205bbe9d7f845369b93f6a02f2e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py @@ -0,0 +1,175 @@ +# mypy: allow-untyped-defs +import inspect +import re +import string +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional +from types import ModuleType + +import torch + +_TAGS: dict[str, dict[str, Any]] = { + "torch": { + "cond": {}, + "dynamic-shape": {}, + "escape-hatch": {}, + "map": {}, + "dynamic-value": {}, + "operator": {}, + "mutation": {}, + }, + "python": { + "assert": {}, + "builtin": {}, + "closure": {}, + "context-manager": {}, + "control-flow": {}, + "data-structure": {}, + "standard-library": {}, + "object-model": {}, + }, +} + + +class SupportLevel(Enum): + """ + Indicates at what stage the feature + used in the example is handled in export. 
def _validate_tag(tag: str):
    """
    Check that a dotted tag path (e.g. ``"torch.dynamic-shape"``) is registered.

    The tag is split on ``.`` and walked through the nested ``_TAGS`` mapping
    one component at a time. Each component may only contain lowercase ASCII
    letters and ``-``.

    Raises:
        AssertionError: if a component contains any other character.
        ValueError: if a component is not a key at its level of ``_TAGS``.
    """
    # Built once per call; it does not depend on the component being checked.
    allowed_chars = frozenset(string.ascii_lowercase + "-")
    level = _TAGS
    for part in tag.split("."):
        assert allowed_chars.issuperset(
            part
        ), f"Tag contains invalid characters: {part}"
        if part not in level:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        level = level[part]
def export_rewrite_case(**kwargs):
    """
    Decorator factory registering a rewrite of an existing export case.

    ``kwargs`` must contain ``parent``, the ``ExportCase`` being rewritten.
    The remaining keywords are passed on as the new case's configuration,
    with ``example_args`` taken from the parent. The new case is named after
    the snake-cased ``__name__`` of the decorated object and appended to
    ``_EXAMPLE_REWRITE_CASES`` under the parent's name.

    Raises:
        KeyError: if ``parent`` was not supplied.
        AssertionError: if ``parent`` is not an ``ExportCase``.
    """

    def wrapper(m):
        # Work on a copy: popping "parent" from the captured kwargs would make
        # a second application of the same decorator fail with KeyError.
        configs = dict(kwargs)

        parent = configs.pop("parent")
        assert isinstance(parent, ExportCase)
        rewrites = _EXAMPLE_REWRITE_CASES.setdefault(parent.name, [])

        configs["example_args"] = parent.example_args
        case = _make_export_case(m, to_snake_case(m.__name__), configs)
        rewrites.append(case)
        return case

    return wrapper
def exportdb_error_message(case_name: str) -> str:
    """
    Build the user-facing message for an unsupported export case.

    When ``case_name`` is one of the names returned by ``all_examples()``,
    the message points at its exportdb documentation anchor (underscores in
    the name become dashes in the URL fragment). Otherwise the miss is
    reported through ``log_export_usage`` and a generic message is returned.
    """
    from .examples import all_examples
    from torch._utils_internal import log_export_usage

    ALL_EXAMPLES = all_examples()
    # Detect whether case_name is really registered in exportdb.
    if case_name in ALL_EXAMPLES:
        url_case_name = case_name.replace("_", "-")
        # Adjacent literals instead of a backslash continuation inside the
        # string: the continuation kept the next line's indentation in the
        # message. NOTE(review): the original spacing is not recoverable from
        # this view; a single space is assumed to be the intent.
        return (
            f"See {case_name} in exportdb for unsupported case. "
            f"https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}"
        )
    else:
        log_export_usage(
            event="export.error.casenotregistered",
            message=case_name,
        )
        return f"{case_name} is unsupported."
class NodeMetadata:
    """
    Dict-like wrapper around node metadata that refuses to overwrite the
    keys listed in ``PROTECTED_KEYS``.

    The wrapper owns a shallow copy of the mapping it is built from, so
    item assignment never changes the caller's dict.
    """

    def __init__(self, data: dict[str, Any]) -> None:
        # Shallow copy: top-level keys are isolated, values are shared.
        self.data: dict[str, Any] = data.copy()

    def __getitem__(self, key: str) -> NodeMetadataValue:
        """Return the value stored under ``key``; raises KeyError if absent."""
        return self.data[key]

    def __setitem__(self, key: str, value: NodeMetadataValue) -> None:
        """
        Store ``value`` under ``key``.

        Raises:
            RuntimeError: if ``key`` is in ``PROTECTED_KEYS``, whether or not
                it is currently present.
        """
        if key in PROTECTED_KEYS:
            raise RuntimeError(f"Could not override node key: {key}")
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self.data

    def copy(self) -> "NodeMetadata":
        """Return a new ``NodeMetadata`` holding a shallow copy of the data."""
        # The constructor already copies, so copying again here is redundant.
        return NodeMetadata(self.data)
Node: {self.proxy_or_node.format_node()}" + ) + return self.proxy_or_node + + def to_tensor(self) -> torch.Tensor: + assert isinstance(self.data, torch.Tensor) + return self.data + + def is_tensor(self) -> bool: + return isinstance(self.data, torch.Tensor) + + # pyre-ignore + def __iter__(self) -> Iterator[_T]: + yield from self.data + + def __bool__(self) -> bool: + return bool(self.data) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9ce2ac03c23600c86ff02e38a2a4bfeefef9e2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py @@ -0,0 +1 @@ +from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d475b1af35ffd1de1057d009ddcfb827b50da59e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736d93214ad289f7109fe53001899e2794352ab4 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e29cb4acf3ab262bc4c01af41bc111d9ab6f9fee Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b46d3117354c0428ba43d467dff1dff381224ce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35f48e13df68fbd87f32f62c868ba7f09d5016cd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30b0aa6f7ffd8e3ad69555ee97b04253c050d8bc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3d441ad04ab153fe26c8372497c48f6251b9b86 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af080b1183cefa44b0a48d6b7299c226fba7873a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..1589add5e152dac6d2a9b5ae638c572df77b7530 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..716bb3b53a8fd59cad8bb51fad6d60b6ca149ccc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39af7ba94a4496ccc40c9aae3336d05b32a37b4b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3547a5f73c77485f7cd63f89ecbd13ef8c642e98 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Any, Optional + 
+import torch +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.graph_module import GraphModule + + +_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" + + +def _node_metadata_hook( + node: torch.fx.Node, + metadata: Optional[dict[str, Any]] = None, + fake_mode: Optional[FakeTensorMode] = None, +) -> None: + """ + Hook for adding the appropriate metadata to nodes that are created during a + pass using graph.create_node. An example of how to use it: + + ``` + with _set_node_metadata_hook(gm, + functools.partial(_node_metadata_hook, metadata={"stack_trace": "file"}) + ): + pass(gm) + ``` + + This hook should not work for all generic cases -- specifically it assumes + that nodes being added are only call_function nodes, and copies over the + first argument node's nn_module_stack. + """ + # pyrefly: ignore [bad-assignment] + fake_mode = fake_mode or contextlib.nullcontext() + + assert node.op == "call_function" and callable(node.target), ( + f"node: {node}, target: {node.target}" + ) + + if ( + isinstance(node.target, torch._ops.OpOverload) + and len(node.target._schema.returns) == 0 + ): + node.meta["val"] = None + else: + fake_args, fake_kwargs = pytree.tree_map_only( + torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs) + ) + # pyrefly: ignore [bad-context-manager] + with fake_mode, enable_python_dispatcher(): + fake_res = node.target(*fake_args, **fake_kwargs) + node.meta["val"] = fake_res + + if metadata is not None: + for k, v in metadata.items(): + node.meta[k] = v + + # Copy over metadata from argument nodes + arg_meta = [ + arg.meta + for arg in pytree.tree_flatten((node.args, node.kwargs))[0] + if isinstance(arg, torch.fx.Node) + ] + if len(arg_meta) == 0: + return + arg_meta = arg_meta[0] + + node.meta["nn_module_stack"] = node.meta.get( + "nn_module_stack", + arg_meta.get( + "nn_module_stack", + { + 
_EMPTY_NN_MODULE_STACK_KEY: ( + _EMPTY_NN_MODULE_STACK_KEY, + _EMPTY_NN_MODULE_STACK_KEY, + ) + }, + ), + ) + + node.meta["torch_fn"] = node.meta.get( + "torch_fn", + ( + f"{node.target.__name__}_0", + # pyrefly: ignore [missing-attribute] + f"{node.target.__class__.__name__}.{node.target.__name__}", + ), + ) + + +@contextlib.contextmanager +def _set_node_metadata_hook(gm: torch.fx.GraphModule, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "node_metadata_hook must be a callable." + + # Add the hook to all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._register_create_node_hook(f) + try: + yield + finally: + # Restore hook for all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._unregister_create_node_hook(f) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..345401e9f76e5e82d462f3a5c56a30bb3e1f5e8a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +import math +import operator +import traceback +from functools import partial +from typing import NamedTuple, TYPE_CHECKING + +import sympy + +import torch +import torch.fx +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + + +if TYPE_CHECKING: + from collections.abc import Callable + + +__all__ = ["InputDim"] + + +class 
InputDim(NamedTuple): + input_name: str + dim: int + + +def _convert_to_int(val): + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return math.inf + if val in (-sympy.oo, -int_oo): + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError("Export constraints cannot be non-integer expressions") + + +def _convert_range_to_int(range: ValueRanges): + assert isinstance(range, ValueRanges) + min_val = _convert_to_int(range.lower) + max_val = _convert_to_int(range.upper) + return min_val, max_val + + +class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): + def __init__( + self, + range_constraints: dict[sympy.Symbol, ValueRanges], + ): + super().__init__() + self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set() + self.counter = 0 + + def _assert_range_constraint(self, node, lower, upper, assert_msg): + last_node = node + if lower > -math.inf: + last_node = self._insert_assert_async( + last_node, operator.ge, node, lower, assert_msg + ) + + if upper < math.inf: + last_node = self._insert_assert_async( + last_node, operator.le, node, upper, assert_msg + ) + + def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. 
+ """ + self.counter += 1 + graph = last_node.graph + with graph.inserting_after(last_node): + cmp = graph.call_function(op, (lower, upper), {}) + with graph.inserting_after(cmp): + cmp_tensor = graph.call_function( + torch.ops.aten.scalar_tensor.default, (cmp,), {} + ) + with graph.inserting_after(cmp_tensor): + assert_async = graph.call_function( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + ) + return assert_async + + def call(self, graph_module) -> PassResult: + self.existing_inline_assertions = _get_existing_inline_assertions( + graph_module, self.range_constraints + ) + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if "val" not in node.meta: + continue + + val = node.meta["val"] + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. + # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requires the proxy for ret[1], etc. 
+ + def add_assertions(val): + call_backs: list[Callable] = [] + messages: list[str] = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node.expr + if symbol in self.existing_inline_assertions: + return call_backs, messages + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols( + symbol + ): + if symbol in self._asserts_generated_unbacked_symbols: + return call_backs, messages + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'u' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." + call_backs.append( + partial( + self._assert_range_constraint, + lower=min_val, + upper=max_val, + ) + ) + messages.append(assert_msg) + self._asserts_generated_unbacked_symbols.add(symbol) + + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + + def sym_size_cb(node, assert_msg, dim): + with node.graph.inserting_after(node): + dim_node = module.graph.call_function( + torch.ops.aten.sym_size.int, + (node, dim), + {}, + ) + cb(node=dim_node, assert_msg=assert_msg) + + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(node=node, assert_msg=f"{node}" + msg) + + module.recompile() + + # Sometimes this pass would return a wrong graph where we have mismatched + # node names in signature. Before we fix it, let's just skip it. 
+ if ( + self.counter == 0 + and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass + ): + return PassResult(graph_module, False) + + # Populate the stack trace with dummy vals to respect IR + for node in graph_module.graph.nodes: + if not node.meta.get("stack_trace", None) and node.op not in [ + "placeholder", + "output", + ]: + node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) + return PassResult(graph_module, True) + + +def _get_existing_inline_assertions( + graph_module: torch.fx.GraphModule, + range_constraints: dict[sympy.Symbol, ValueRanges], +) -> dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {} + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + # Find all the existing inline assertions. They will look something like: + # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) + # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) + # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {}) + for node in module.graph.nodes: + if node.target != torch.ops.aten._assert_scalar.default: + continue + + compare_arg = node.args[0] + if not ( + isinstance(compare_arg, torch.fx.Node) + and compare_arg.op == "call_function" + and compare_arg.target in (operator.le, operator.ge) + and len(compare_arg.args) == 2 + ): + continue + + compare_op = compare_arg.target + lhs, rhs = compare_arg.args + + def maybe_get_symint(x): + if ( + isinstance(x, torch.fx.Node) + and "val" in x.meta + and isinstance(x.meta["val"], torch.SymInt) + ): + return x.meta["val"].node.expr + return x + + lhs = maybe_get_symint(lhs) + rhs = maybe_get_symint(rhs) + + if compare_op is operator.ge: + lhs, rhs = rhs, lhs + + if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int): + symint = lhs + scalar 
= rhs + elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int): + symint = rhs + scalar = lhs + else: + continue + + if symint not in range_constraints: + raise RuntimeError( + f"Unable to find symint {symint} in {range_constraints}" + ) + + previous_range = existing_inline_assertions.get( + symint, ValueRanges(-math.inf, math.inf) + ) + + if symint is lhs: + bounds = ValueRanges(-math.inf, scalar) + else: + bounds = ValueRanges(scalar, math.inf) + existing_inline_assertions[symint] = previous_range & bounds + + return existing_inline_assertions diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a82564886889deabfc758d61e32289ab7843a2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py @@ -0,0 +1,146 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +from typing import TYPE_CHECKING + +import torch +from torch.export.exported_program import ConstantArgument, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +if TYPE_CHECKING: + from torch.export.exported_program import ModuleCallSignature + from torch.export.graph_signature import ExportGraphSignature + + +__all__ = ["CollectTracepointsPass"] + + +class CollectTracepointsPass(PassBase): + """ + Performs constant folding and constant propagation. 
+ """ + + def __init__( + self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature + ) -> None: + super().__init__() + self.specs = specs + self.sig = sig + + def call(self, gm: torch.fx.GraphModule) -> PassResult | None: + def get_arg_spec(arg) -> TensorArgument | ConstantArgument: + if isinstance(arg, torch.fx.Node): + if isinstance(arg.meta.get("val"), torch.Tensor): + return TensorArgument(name=arg.name) + else: + raise AssertionError( + "Symint input is not implemented yet for submodule call signature." + ) + else: + return ConstantArgument(name="", value=arg) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + nn_module_stack = None + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target is torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_outputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_inputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + nn_module_stack = None + for node in reversed(module.graph.nodes): + if node.op != "call_function": + continue + if node.target is torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_outputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + + def copy_sig(sig) -> ModuleCallSignature: + from torch.export.exported_program import ModuleCallSignature + + return ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=sig.in_spec, + out_spec=sig.out_spec, + 
forward_arg_names=None, + ) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target is torch.ops.higher_order._export_tracepoint: + # There's some subtlety worth noting. Here fqn corresponds to + # the call name, whereas path corresponds to the module name. + # They are not necessarily the same! When a submodule is shared + # through different aliases, there are as many _export_tracepoint + # markers as there are aliases, since the shared submodule is + # wrapped once for each alias. + path = node.kwargs["path"] + fqn, _ = next(reversed(node.meta["nn_module_stack"].values())) + + module_key = next(reversed(node.meta["nn_module_stack"])) + if "@" in module_key: + suffix = module_key.split("@")[-1] + path = f"{path}@{suffix}" + + call_fqn = f"{fqn}@{suffix}" + if call_fqn not in self.specs: + self.specs[call_fqn] = copy_sig(self.specs[fqn]) + fqn = call_fqn + + kind = node.kwargs["kind"] + for i, arg in enumerate(node.args): + # We only update the signature of the alias used to call + # the submodule. Otherwise the signatures of all aliases + # would get conflated; the inputs/outputs of every call + # would be recorded in every other call as well. 
+ if fqn == path: + if kind == "module_call_inputs": + self.specs[path].inputs.append(get_arg_spec(arg)) + elif kind == "module_call_outputs": + self.specs[path].outputs.append(get_arg_spec(arg)) + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + if isinstance(arg, torch.fx.Node): + for user in node.users: + assert user.op == "call_function" + assert user.target is operator.getitem + assert isinstance(user.args[1], int) + if user.args[1] == i: + user.replace_all_uses_with(arg) + self.sig.replace_all_uses(user.name, arg.name) + break + users = list(node.users) + for user in users: + assert len(user.users) == 0 + gm.graph.erase_node(user) + gm.graph.erase_node(node) + return PassResult(gm, True) + + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..58534856422c73b20fc85877c8d13ea88532aa45 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py @@ -0,0 +1,304 @@ +# mypy: allow-untyped-defs +import collections +from collections import defaultdict +from collections.abc import Callable +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
+# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + ): + super().__init__(gm) + self.node_replacements: dict[torch.fx.Node, Any] = {} + self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.Node) -> bool: + if ( + node.target is torch.ops.prims.convert_element_type.default + and node.args[0].op == "get_attr" # type: ignore[union-attr] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.pt2e_quant.dequantize_affine, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr] + + for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if node.target == "output": + continue + + def add_use(inp): + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. 
+ if any( + type(self.unknown_value) is type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target is aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def 
run(self): # type: ignore[override] + env = {} + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + env[n] = self.unknown_value + return super().run(initial_env=env) + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +): + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + # Get all attr users by looking up the graph instead from node.users, because in this case + # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. + + # opcode name target args kwargs + # ------------- ------------------- ---------------- --------------------------- -------- + # placeholder arg0_1 arg0 () {} + # get_attr _tensor_constant0 state () {} + # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} + # get_attr _tensor_constant0_1 state () {} + # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} + # output output output ([add],) {} + + get_attr_node_users = defaultdict(list) + for node in gm.graph.nodes: + if node.op == "get_attr": + get_attr_node_users[node.target].extend(node.users.keys()) + for node in gm.graph.find_nodes(op="get_attr"): + if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag(gm: torch.fx.GraphModule) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op 
== "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.find_nodes(op="get_attr"): + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..45dd734c72959cd23c00d88e18dbcf80b8cd3227 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -0,0 +1,99 @@ +import copy +from typing import Optional + +import torch +from torch._export.pass_base import ( + _ExportPassBaseDeprecatedDoNotUse, + Argument, + PassResult, +) +from 
torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._ops import OpOverload + + +aten = torch.ops.aten + +_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = { + aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default, + aten._assert_async.msg: aten._functional_assert_async.msg, +} + + +class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Functionalize ops with side effect in graph module by replacing the op with + functional version of it. A new dependency token (`dep_token`) will be + created and propagated through functional ops to output. + For example: + ``` + def f(x): + sym_constrain_range(x.shape[0], min=1, max=3) + return x.add(3) + ``` + Will be transformed to: + ``` + def f(x): + dep_token0 = _make_dep_token() + dep_token1 = _functional_sym_constrain_range( + x.shape[0], min=1, max=3, dep_token=dep_token0 + ) + + return x.add(3), dep_token1 + ``` + """ + + def __init__(self) -> None: + super().__init__() + self._dep_token: Optional[ProxyValue] = None + self._next_dep_token_index: Optional[int] = None + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Early return if no non-functional assertions. 
+ if not any( + n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS + for n in graph_module.graph.nodes + ): + return PassResult(graph_module=graph_module, modified=False) + + gm = copy.deepcopy(graph_module) + self._dep_token = None + self._next_dep_token_index = None + return super().call(gm) + + def call_operator( + self, + op: OpOverload, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: + return super().call_operator(op, args, kwargs, meta) + + if self._dep_token is None: + self._dep_token = super().call_operator( + aten._make_dep_token, + args=(), + kwargs={}, + meta=self._create_dummy_node_metadata(), + ) + self._dep_token.node.name = "dep_token0" + self._next_dep_token_index = 1 + + self._dep_token = super().call_operator( + _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op], + args=args, + kwargs={**kwargs, "dep_token": self._dep_token}, + meta=meta, + ) + assert self._next_dep_token_index is not None + self._dep_token.node.name = f"dep_token{self._next_dep_token_index}" + self._next_dep_token_index += 1 + + return self._dep_token + + def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: + assert self._dep_token is not None + + return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1e5fb6a9d7fb47ed6d2a9164313b04bbab37c6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py @@ -0,0 +1,80 @@ +import functools +from collections import defaultdict + +import torch +from 
torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, +) +from torch._library.fake_profile import OpProfile, TensorMetadata + + +def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> None: + """ + This is used by draft_export to insert guards in front of calls to custom + operators which have a generated fake kernel. + """ + for node in gm.graph.nodes: + if node.op == "call_function" and str(node.target) in ops_to_guard: + with ( + _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + metadata={"stack_trace": node.meta.get("stack_trace")}, + ), + ), + gm.graph.inserting_before(node), + ): + for arg in (*node.args, *node.kwargs.values()): + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta.get("val"), torch.Tensor + ): + val = arg.meta["val"] + gm.graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(arg,), + kwargs={ + "dtype": val.dtype, + "device": val.device, + "layout": val.layout, + }, + ) + + gm.recompile() + + +def get_op_profiles( + gm: torch.fx.GraphModule, ops_to_guard: set[str] +) -> dict[str, set[OpProfile]]: + """ + This is used by draft_export to get a list of custom operator profiles so + that we can generate fake kernels. 
+ """ + + def _get_op_profile(node: torch.fx.Node) -> OpProfile: + args_profile = tuple( + TensorMetadata.maybe_from_tensor(arg.meta.get("val")) + if isinstance(arg, torch.fx.Node) + else None + for arg in (*node.args, *node.kwargs.values()) + ) + + out_profile = None + meta = node.meta.get("val") + assert meta is not None + if isinstance(meta, torch.Tensor): + out_profile = TensorMetadata.maybe_from_tensor(meta) + elif isinstance(meta, (list, tuple)): + out_profile = tuple(TensorMetadata.maybe_from_tensor(m) for m in meta) # type: ignore[assignment] + assert out_profile is not None + + return OpProfile(args_profile, out_profile) # type: ignore[arg-type] + + op_profiles: dict[str, set[OpProfile]] = defaultdict(set) + + for node in gm.graph.nodes: + if node.op == "call_function" and str(node.target) in ops_to_guard: + op_profiles[str(node.target)].add(_get_op_profile(node)) + + return op_profiles diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..607989cd919cbb6d4cf59aab3071a9f7c5b5375f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py @@ -0,0 +1,417 @@ +# mypy: allow-untyped-defs +import collections +import logging +from typing import Any, Optional, Union + +import torch +from torch._export.verifier import SpecViolationError +from torch._guards import detect_fake_mode +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_reference_type +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.export.exported_program import ( + ArgumentSpec, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + TensorArgument, +) +from torch.fx._symbolic_trace import 
_ConstantAttributeType +from torch.fx.graph_module import _get_attr + + +log = logging.getLogger(__name__) + + +class ConstantAttrMap(collections.abc.MutableMapping): + """A mapping class that understands how to use module constants (tensors, + ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, + but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to + the same underlying value (but we guarantee that they will `hash()` to the same value + if that's the case). + """ + + def __init__(self) -> None: + # Underlying dict that we use to implement this mapping. + self._constant_attrs: dict[ + Union[int, torch.Tensor, FakeScriptObject, torch.utils._pytree.TreeSpec], + list[Any], + ] = {} + # Map from the hash(ScriptObject) to the ScriptObject itself. Used for + # APIs like `__iter__` that should look like they're returning the + # original ScriptObjects. + self._script_object_map: dict[int, torch.ScriptObject] = {} + + def __getitem__(self, key: _ConstantAttributeType) -> Any: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject)) + return self._constant_attrs[real_key] + + def __setitem__(self, key: _ConstantAttributeType, value): + # we shouldn't actually call this, should go to add() instead to handle aliasing + raise NotImplementedError( + """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead. 
+The same key can be mapped to multiple values, for handling constant aliasing.""" + ) + + def add(self, key: _ConstantAttributeType, value: Any) -> None: + if isinstance(key, torch.ScriptObject): + if hash(key) not in self._constant_attrs: + self._constant_attrs[hash(key)] = [] + self._constant_attrs[hash(key)].append(value) + self._script_object_map[hash(key)] = key + elif isinstance(key, (torch.Tensor, FakeScriptObject)): + if key not in self._constant_attrs: + self._constant_attrs[key] = [] + self._constant_attrs[key].append(value) + else: + raise TypeError( + f"Expected key to be a tensor or ScriptObject, got {type(key)}" + ) + + def __delitem__(self, key: _ConstantAttributeType): + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + + del self._constant_attrs[real_key] + + def __iter__(self): + for key in self._constant_attrs: + if isinstance(key, int): + yield self._script_object_map[key] + else: + yield key + + def __len__(self): + return len(self._constant_attrs) + + def __contains__(self, key: object) -> bool: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + return real_key in self._constant_attrs + + +def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: + # The FQN of the constant tensor in the state dict should + # correspond to the module where the constant tensor was + # originally used. 
def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
    """
    If there is a tensor constant created while tracing, here is how the graph
    looks like:

        %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
        %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,))
        %detach_ : [num_users=?] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,))

    To check to see if the tensor constant is being used, we want to traverse to
    the detach node to see if it's actually being used.

    This function returns None if this constant is being used, otherwise it returns
    the nodes to be removed later, ordered users-first so they can be erased in
    sequence: the detach node (if any), the lift_fresh node (if any), then
    ``node`` itself.
    """  # noqa: B950
    if len(node.users) > 1:
        return None

    # Case 0: the constant has no users at all, so it is trivially unused.
    # Without this guard, next(iter(...)) below raises StopIteration.
    if len(node.users) == 0:
        return [node]

    lift_fresh_node = next(iter(node.users))
    if not (
        lift_fresh_node.op == "call_function"
        and lift_fresh_node.target
        in (
            torch.ops.aten.lift_fresh.default,
            torch.ops.aten.lift_fresh_copy.default,
        )
    ):
        return None

    if len(lift_fresh_node.users) > 1:
        return None

    # Case 1: lift node is not used anywhere
    if len(lift_fresh_node.users) == 0:
        return [lift_fresh_node, node]

    detach_node = next(iter(lift_fresh_node.users))
    if not (
        detach_node.op == "call_function"
        and detach_node.target
        in (
            torch.ops.aten.detach_.default,
            torch.ops.aten.detach.default,
        )
    ):
        return None

    if len(detach_node.users) > 0:
        return None

    # Case 2: Lift node's child is not used anywhere
    return [detach_node, lift_fresh_node, node]
+ """ + all_constants: dict[str, _ConstantAttributeType] = {} + + input_specs = graph_signature.input_specs + num_custom_obj = sum( + input_spec.kind == InputKind.CUSTOM_OBJ for input_spec in input_specs + ) + num_tensor_constants = sum( + input_spec.kind == InputKind.CONSTANT_TENSOR for input_spec in input_specs + ) + + fake_mode = detect_fake_mode( + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") + ) + + first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes)) + used_target_names = set() + + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(input_nodes) == len(input_specs) + for i, (node, input_spec) in enumerate(zip(input_nodes, input_specs)): + used_target_names.add(input_spec.target) + if input_spec.kind == InputKind.USER_INPUT: + first_user_input = node + first_user_input_loc = i + break + + lifted_objs = ConstantAttrMap() + renamed_targets = {} + for node in list(gm.graph.nodes): + if node.op == "get_attr": + if nodes_to_remove := _unused_constant(node): + # Remove the node if it's not being used + for node_rm in nodes_to_remove: + gm.graph.erase_node(node_rm) + continue + + constant_val = _get_attr(gm, node.target) + # These are not hashable and not gonna be lifted + # so we can skip them earlier + if isinstance(constant_val, torch.fx.GraphModule): + continue + if "LoweredBackendModule" in type(constant_val).__name__: + continue + if "AOTInductorRunnerWrapper" in type(constant_val).__name__: + continue + if isinstance(constant_val, torch.utils._pytree.TreeSpec): + continue + + if constant_val in lifted_objs: + # We already lifted this constant elsewhere. Just rewrite uses + # of this get_attr to point to the already-existing placeholder + # node. 
+ const_placeholder_node = _get_first_fqn(lifted_objs, constant_val) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + renamed_targets[node.name] = const_placeholder_node.name + continue + + # For ScriptObject, Tensor and FakeScriptObject constants: + # First check if the constant was an attribute on some module by + # consulting `constant_attrs` map. If it is, use the fqn that keeps + # its location consistent with the eager module. + # + # If it's not in the `constant_attrs` map, that means it's an inline + # constant (e.g. x + torch.tensor(0)), and thus did not have a + # specific location in the eager module. In that case, just generate + # some name and attach it to the module in which it was used. + if isinstance( + constant_val, (torch.ScriptObject, FakeScriptObject) + ) or is_opaque_reference_type(type(constant_val)): + constant_kind = InputKind.CUSTOM_OBJ + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + while constant_fqn in used_target_names: + num_custom_obj += 1 + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + num_custom_obj += 1 + elif isinstance(constant_val, torch.Tensor): + # Remove the parameterness of constant_val + if isinstance(constant_val, torch.nn.Parameter): + log.debug( + "%s created when tracing %s is a parameter. But " + "it's not registered with register_parameter(). export will treat it as a constant tensor", + str(node.target), + str(node.meta.get("stack_trace", "")), + ) + # We get the real data out of the parameter by disabling the surrounding fake mode. 
+ with unset_fake_temporarily(): + constant_val = constant_val.data + constant_kind = InputKind.CONSTANT_TENSOR + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + while constant_fqn in used_target_names: + num_tensor_constants += 1 + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + num_tensor_constants += 1 + else: + raise SpecViolationError( + f"getattr node {node} referencing unsupported type {type(constant_val)}" + ) + + with gm.graph.inserting_before(first_user_input): + # Insert the constant node before the first user input + const_placeholder_node = gm.graph.placeholder(constant_name) + # match target name with its node name in case there is name collision + # and suffix is added to node name in fx + const_placeholder_node.target = const_placeholder_node.name + + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + + # Once the FQN has been used, remove nn_module_stack, stack_trace + const_placeholder_node.meta.pop("nn_module_stack") + const_placeholder_node.meta.pop("stack_trace", None) + + input_spec_arg: ArgumentSpec + if isinstance(constant_val, torch.Tensor): + if fake_mode is not None: + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_val, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = constant_val + else: + const_placeholder_node.meta["val"] = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) + elif isinstance(constant_val, torch._C.ScriptObject): + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, 
class_fqn=class_fqn + ) + elif isinstance(constant_val, FakeScriptObject): + class_fqn = constant_val.script_class_name + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn, constant_val + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, + class_fqn=class_fqn, + fake_val=constant_val, + ) + else: + raise SpecViolationError( + f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" + ) + + lifted_objs.add(constant_val, const_placeholder_node) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + + renamed_targets[node.name] = const_placeholder_node.name + + # Add the constant as a buffer to the graph signature + graph_signature.input_specs.insert( + first_user_input_loc, + InputSpec( + kind=constant_kind, + arg=input_spec_arg, + target=constant_fqn, + ), + ) + if constant_val in constant_attrs: + for fqn in constant_attrs[constant_val]: + all_constants[fqn] = constant_val + else: + all_constants[constant_fqn] = constant_val + first_user_input_loc += 1 + + for spec in graph_signature.output_specs: + if spec.arg.name in renamed_targets: + spec.arg.name = renamed_targets[spec.arg.name] + + return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> dict[str, _ConstantAttributeType]: + """When tracing, we produce a graph with FakeScriptObject in the + meta["val"]. 
+ + For now, we rewrie meta["val"] to be a placeholder CustomObjArgument + """ + constants: dict[ + str, + _ConstantAttributeType, + ] = {} + for node in gm.graph.nodes: + if "val" not in node.meta: + continue + + old_meta = node.meta["val"] + + if isinstance(old_meta, torch.ScriptObject): + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + elif isinstance(old_meta, FakeScriptObject): + class_fqn = old_meta.script_class_name # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn, old_meta) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants + + +def _materialize_and_lift_constants( + gm: torch.fx.GraphModule, + export_graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> dict[str, _ConstantAttributeType]: + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + return constants diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py new file mode 100644 index 0000000000000000000000000000000000000000..ceed7cd23aa0e953b99586052629668cc53c4bdd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py @@ -0,0 +1,36 @@ +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class _RemoveRuntimeAssertionsPass(PassBase): + """ + Remove runtime assertions inserted by the + _AddRuntimeAssertionsForInlineConstraintsPass. 
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target in [ + torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten._assert_tensor_metadata.default, + ]: + assert_async_node = node + if len(assert_async_node.users) > 0: + continue + module.graph.erase_node(assert_async_node) + # the upstream scalar_tensor <- {le, ge} <- sym_size + # linear chain of nodes of nodes is removed by the + # downstream dead code elimination + modified = True + + # We don't necessarily want to run DCE here because it could affect + # nodes that are in the module_call_graph attribute of the exported + # program. We will leave it to the pass caller to call DCE. + return PassResult(graph_module, modified) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..14ab3e817ed703cbe0844198deca5c06f2e6effc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py @@ -0,0 +1,189 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch._higher_order_ops.wrap import wrap_with_autocast + +from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +if TYPE_CHECKING: + from 
torch.export.graph_signature import ExportGraphSignature + + +def _is_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target + in [ + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + ] + ) + + +def _is_enter_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target is torch.amp.autocast_mode._enter_autocast + ) + + +def _is_exit_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target is torch.amp.autocast_mode._exit_autocast + ) + + +def _is_autocast_sub_mod(node: torch.fx.Node) -> bool: + """ + Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`. + """ + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target is torch.amp.autocast_mode._enter_autocast + ): + # TODO: check if current auto-cast type is the same as the args of + # _enter_autocast. If so, return False, i.e. do not create a submodule. 
+ return True + return False + + +def _check_valid_autocast_block( + enter_autocast_node: torch.fx.Node, exit_autocast_node: torch.fx.Node +) -> None: + assert _is_enter_autocast_node(enter_autocast_node) + assert _is_exit_autocast_node(exit_autocast_node) + assert exit_autocast_node.args[0] == enter_autocast_node + + +def _replace_with_hop(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node) + if len(autocast_nodes) > 0: + assert len(autocast_nodes) > 1 # need at least an enter node and an exist node + enter_autocast_node = autocast_nodes[0] + exit_autocast_node = autocast_nodes[-1] + _check_valid_autocast_block(enter_autocast_node, exit_autocast_node) + + _replace_with_hop_helper(node, enter_autocast_node, wrap_with_autocast) + sub_graph.erase_node(exit_autocast_node) + sub_graph.erase_node(enter_autocast_node) + + +def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + split_autocast creates a new graph module that splits the input graph module into multiple submodules + based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module. + + Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split + into a submodule. Nested autocast regions are not split. + `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well. + + Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph + module. Nodes marked with the same number are grouped into the same submodule. 
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
    """
    Helper function for replace_autocast_with_hop_pass().
    Split the graph module into multiple subgraphs based on the autocast nodes.
    Each resulting `call_module` node is either turned into a HOP call (when its
    submodule starts with `_enter_autocast`) or inlined back into the parent
    graph module.
    Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
    as a subgraph.

    When the graph has no autocast node at all, `gm` and `graph_signature` are
    returned unchanged.
    """
    for candidate in gm.graph.nodes:
        if _is_autocast_node(candidate):
            break
    else:
        # No autocast region anywhere: nothing to split.
        return gm, graph_signature

    # split_autocast returns a new graph module that could have different output
    # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
    split_gm = _split_autocast(gm)

    def _handle_submodule_call(node: torch.fx.Node) -> None:
        if _is_autocast_sub_mod(node):
            _replace_with_hop(node)
            return
        assert node.op == "call_module"
        assert isinstance(node.target, str)
        node_inline_(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        split_gm, graph_signature, _handle_submodule_call
    )


def replace_autocast_with_hop_pass(
    gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
    """
    Entry point of the autocast-to-HOP rewrite.

    Delegates to `_replace_with_hop_pass_helper`, passing
    `_sequential_split_and_maybe_inline_subgraphs` as the per-module splitting
    routine.
    """
    splitter = _sequential_split_and_maybe_inline_subgraphs
    return _replace_with_hop_pass_helper(gm, graph_signature, splitter)
+_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None +_SCALE: Optional[Union[float, torch.fx.Node]] = None +_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None + + +def int_to_valid_dtype(val: int) -> torch.dtype: + from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import. + + if isinstance(val, torch.dtype): + return val + dtype = _TORCH_ENUM_TO_DTYPE[val] + if dtype == torch.quint8: + return torch.uint8 + elif dtype == torch.qint8: + return torch.int8 + return dtype + + +def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node: + return gm.graph.call_function(int_to_valid_dtype, (val,)) + + +def insert_quantized_node( + gm: torch.fx.GraphModule, + val_node: torch.fx.Node, + scale_node: Union[float, torch.fx.Node], + zero_point_node: Union[float, torch.fx.Node], + qmin_node: Union[float, int, torch.fx.Node], + qmax_node: Union[float, int, torch.fx.Node], + dtype_node: Union[torch.dtype, torch.fx.Node], + qscheme: Optional[torch.qscheme], +) -> torch.fx.Node: + return gm.graph.call_function( + quantize_per_tensor, + ( + val_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + + +def get_dequantized( + val: torch.Tensor, + scale: Union[float, torch.Tensor], + zero_point: Union[float, torch.Tensor], + qmin: Union[float, int], + qmax: Union[float, int], + dtype: torch.dtype, + axis: Optional[int], + qscheme: Optional[torch.qscheme], +) -> torch.Tensor: + if qscheme is torch.per_tensor_affine: + return dequantize_per_tensor( + val, + scale, # type: ignore[arg-type] + zero_point, # type: ignore[arg-type] + qmin, # type: ignore[arg-type] + qmax, # type: ignore[arg-type] + dtype, + ) + elif qscheme is torch.per_channel_affine: + return dequantize_per_channel( + val, + scale, # type: ignore[arg-type] + zero_point, # type: ignore[arg-type] + axis, # type: ignore[arg-type] + qmin, # type: ignore[arg-type] + qmax, # type: ignore[arg-type] + dtype, + ) + else: + raise 
RuntimeError(f"Unsupported dequantization scheme: {qscheme}") + + +def insert_dequantized_node( + gm: torch.fx.GraphModule, + val_node: torch.fx.Node, + scale_node: Union[float, torch.fx.Node], + zero_point_node: Union[float, torch.fx.Node], + qmin_node: Union[float, int, torch.fx.Node], + qmax_node: Union[float, int, torch.fx.Node], + dtype_node: Union[torch.dtype, torch.fx.Node], + axis_node: Optional[Union[int, torch.fx.Node]], + qscheme: Optional[torch.qscheme], +) -> torch.fx.Node: + if qscheme is torch.per_tensor_affine: + return gm.graph.call_function( + dequantize_per_tensor, + ( + val_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + elif qscheme is torch.per_channel_affine: + return gm.graph.call_function( + dequantize_per_channel, + ( + val_node, + scale_node, + zero_point_node, + axis_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + else: + raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") + + +def get_qmin_qmax(dtype: torch.dtype) -> tuple[Union[int, float], Union[int, float]]: + return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type] + + +def insert_qmin_qmax_node( + gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node] +) -> tuple[torch.fx.Node, torch.fx.Node]: + q_min_max_node = gm.graph.call_function( + calculate_qmin_qmax, (None, None, False, dtype_node, False) + ) + qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0)) + qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1)) + return qmin_node, qmax_node + + +def get_script_object( + gm: torch.nn.Module, node: torch.fx.Node +) -> torch._C.ScriptObject: + assert isinstance(node, torch.fx.Node) + assert node.op == "get_attr" + attr_name = node.target + assert isinstance(attr_name, str) + + mod = gm + for attr in attr_name.split("."): + mod = getattr(mod, attr) + assert isinstance(mod, torch._C.ScriptObject) + return mod + + +def 
insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm: torch.fx.GraphModule, + param_node: torch.fx.Node, +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + """Directly inline tensor from a get_attr fx node.""" + mod = get_script_object(gm, param_node) + w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined] + w_attr_name, b_attr_name = ( + f"dequantized_{param_node.target}_w", + f"dequantized_{param_node.target}_b", + ) + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm: torch.fx.GraphModule, + get_attr_to_weight_node: torch.fx.Node, + get_attr_to_bias_node: Optional[torch.fx.Node], +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + assert isinstance(get_attr_to_weight_node.target, str) + w_qtensor = getattr(gm, get_attr_to_weight_node.target) + w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w" + + if get_attr_to_bias_node is not None: + assert isinstance(get_attr_to_bias_node.target, str) + b_qtensor = getattr(gm, get_attr_to_bias_node.target) + b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b" + else: + b_qtensor, b_attr_name = None, "" + + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node( + gm: torch.fx.GraphModule, + w_qtensor: torch.Tensor, + b_qtensor: Optional[torch.Tensor], + w_attr_name: str, + b_attr_name: str, +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + w_tensor = get_tensor_from_qtensor(w_qtensor) + _assign_attr(w_tensor, gm, w_attr_name) + w_tensor_attr = gm.graph.get_attr(w_attr_name) + + if b_qtensor is not None: + b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False) + _assign_attr(b_tensor, gm, b_attr_name) + b_tensor_attr = gm.graph.get_attr(b_attr_name) + else: + b_tensor_attr = None + + return w_tensor_attr, b_tensor_attr + + +def 
def insert_fused_activation_node(
    gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
) -> torch.fx.Node:
    """Add the relu that a fused ``*_relu`` quantized op implies.

    For the relu-fused op names listed below, a ``call_function`` node
    applying ``torch.ops.aten.relu`` to ``fx_node`` is created in
    ``gm.graph`` and returned. For any other ``opname`` the graph is left
    untouched and ``fx_node`` itself is returned.
    """
    relu_fused_ops = (
        "conv1d_relu",
        "conv2d_relu",
        "linear_relu",
        "add_relu",
        "mul_relu",
    )
    if opname not in relu_fused_ops:
        return fx_node
    return gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Rewrite a quantized conv node as a standard convolution call.

    ``node.args`` is read as ``(input, packed_params, scale, zero_point)``.
    The packed parameters are replaced by get_attr nodes for the weight and
    bias, and a ``call_function`` node is created that calls
    ``aten.conv2d`` (for ``conv2d`` / ``conv2d_relu``) or
    ``_conv1d_op_with_squeeze`` (for every other op name).

    Returns:
        ``(conv_result_node, scale, zero_point)`` where scale and zero point
        are taken unchanged from ``node.args[2]`` and ``node.args[3]``.
    """
    assert isinstance(node.target, torch._ops.OpOverload)
    conv_name = node.target._opname
    quant_scale, quant_zero_point = node.args[2], node.args[3]

    if conv_name in ("conv2d", "conv2d_relu"):
        conv_impl = torch.ops.aten.conv2d
    else:
        conv_impl = _conv1d_op_with_squeeze

    data_node, packed_node = node.args[0], node.args[1]
    assert isinstance(data_node, torch.fx.Node)
    assert isinstance(packed_node, torch.fx.Node)

    if packed_node.op != "call_function":
        # Using ConvPrepackedParam: the conv hyper-parameters are read from
        # the ScriptObject itself.
        packed = get_script_object(gm, packed_node)
        (
            weight_attr,
            bias_attr,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
            gm, packed_node
        )  # type: ignore[assignment]
        conv_args = (
            data_node,
            weight_attr,
            bias_attr,
            packed.stride(),  # type: ignore[attr-defined]
            packed.padding(),  # type: ignore[attr-defined]
            packed.dilation(),  # type: ignore[attr-defined]
            packed.groups(),  # type: ignore[attr-defined]
        )
    else:
        # Using Conv2dPrepackParam from conv_prepack.
        # The packing call is skipped; its weight and bias are inlined and
        # its remaining arguments are forwarded to the convolution as-is.
        weight_src, bias_src = packed_node.args[0], packed_node.args[1]
        assert isinstance(weight_src, torch.fx.Node)
        assert bias_src is None or isinstance(bias_src, torch.fx.Node)
        (
            weight_attr,
            bias_attr,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
            gm, weight_src, bias_src
        )
        conv_args = (data_node, weight_attr, bias_attr, *packed_node.args[2:])

    conv_node = gm.graph.call_function(conv_impl, conv_args)
    return conv_node, quant_scale, quant_zero_point
+ ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm, param_node + ) # type: ignore[assignment] + op_res_node = gm.graph.call_function( + torch.ops.aten.linear, (inp_node, param_0, param_1) + ) + return op_res_node, scale_node, zero_point_node + + +def _transform_op_where_last_two_arguments_are_scale_and_zero_point( + gm: torch.fx.GraphModule, node: torch.fx.Node +): + """ + This transformation function can be used for function where the last two + parameters are scale and zero point. Additionally, the function's parameters + do not need any unpacking. + """ + to_standard_op = { + "mul": torch.ops.aten.mul, + "mul_relu": torch.ops.aten.mul, + "add": torch.ops.aten.add, + "add_relu": torch.ops.aten.add, + "softmax": torch.ops.aten.softmax, + "cat": torch.ops.aten.cat, + "hardswish": torch.ops.aten.hardswish, + } + + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2])) + return op_res_node, scale_node, zero_point_node + + +def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Transform scalar overload for basic arithmetic.""" + to_standard_op = { + "mul": torch.ops.aten.mul.Scalar, + "add": torch.ops.aten.add.Scalar, + } + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_res_node = gm.graph.call_function(to_standard_op[opname], args) + return op_res_node, _SCALE, _ZERO_POINT + + +def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node): + """ + Transformation for functions under prepacked namespace, where they share + the same handling logic that [...]OpContext contains all parameters. 
+ """ + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_f = None + if opname == "conv2d_clamp_run": + op_f = torch.ops.aten.conv2d + elif opname == "linear_clamp_run": + op_f = torch.ops.aten.linear + else: + raise RuntimeError(f"Invalid operator {opname}") + + assert isinstance(args[1], torch.fx.Node) + so = get_script_object(gm, args[1]) + + func_args = [] + func_args += [args[0]] + func_args += so.unpack()[:2] # type: ignore[attr-defined] + if opname == "conv2d_clamp_run": + func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:] + + op_res_node = gm.graph.call_function(op_f, tuple(func_args)) + return op_res_node + + +def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node): + args = node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function( + torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3]) + ) + op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0)) + return op_res_node, scale_node, zero_point_node + + +def fx_transform_quantized_op_to_standard_op( + gm: torch.fx.GraphModule, node: torch.fx.Node +) -> torch.fx.Node: + global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE + + assert isinstance(node.target, torch._ops.OpOverload) + opname, overload = node.target._opname, node.target._overloadname + + key = f"{opname}.{overload}" + opname_to_transform_f = { + "conv1d.new": _transform_conv_with_packedparam, + "conv1d_relu.new": _transform_conv_with_packedparam, + "conv1d.default": _transform_conv_with_packedparam, + "conv1d_relu.default": _transform_conv_with_packedparam, + "conv2d.new": _transform_conv_with_packedparam, + "conv2d_relu.new": _transform_conv_with_packedparam, + "conv2d.default": _transform_conv_with_packedparam, + "conv2d_relu.default": _transform_conv_with_packedparam, + "linear.default": _transform_linear_with_packedparam, + "linear_relu.default": _transform_linear_with_packedparam, + 
"add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "batch_norm2d.default": _transform_batch_norm, + "mul.Scalar": _transform_scalar_arithmetic, + "add.Scalar": _transform_scalar_arithmetic, + } + + if f"{key}" not in opname_to_transform_f: + raise RuntimeError(f"Unsupported quantized op during transformation: {key}") + + op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node) + + # Add fused activation layer. + op_res_node = insert_fused_activation_node(gm, opname, op_res_node) + _SCALE, _ZERO_POINT = scale_node, zero_point_node + + assert _INPUT_Q_DTYPE is not None + qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE) + q_fx_node = insert_quantized_node( + gm, + op_res_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + None, + torch.per_tensor_affine, + ) + return dq_fx_node + + +def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): + """ + Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with + PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv). 
+ + Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y + + After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y + + (qd == quantized_decomposed library, q = quantize, dq = dequantize) + ^ + | + getattr(w), getattr(b) from Conv2dParamPrepack + + During each iteration, the transformation spits out the transformed operator, its quantized output, + and its dequantized value together. We did this because dequantization need to use the + scale and zero point parameters from the quantization to recover the approximate original value. After each + iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear). + + For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject. + During the transformation, we unpack those objects, get their dequantized tensor, populate those + as attributes to the module, and use getattr to access them. + + One exception in the transformation is conv_prepack and linear_prepack. Those calls pack + weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls. + During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the + quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters + to the operator by converting them to a getattr fx.node. + + For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear + without the need of doing de/quantization. + + Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization + data type, which is the same across the entire program, but it only shows up in the very first quantization + call. 
_SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar. + """ + + global _INPUT_Q_DTYPE + + quantized = False + + last_quantized_node = None + # pyrefly: ignore [bad-assignment] + for node in gm.graph.nodes: + if isinstance(node.target, OpOverload): + with gm.graph.inserting_before(node): + namespace, opname = node.target.namespace, node.target._opname + if namespace == "quantized" and opname not in [ + "conv_prepack", + "linear_prepack", + ]: + quantized = True + fx_node = fx_transform_quantized_op_to_standard_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "prepacked": + quantized = True + fx_node = _transform_prepacked_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "aten" and opname == "quantize_per_tensor": + inp_node, scale_node, zero_point_node, dtype_node = node.args + dtype_node = fx_enum_to_dtype(gm, dtype_node) + _INPUT_Q_DTYPE = dtype_node + qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node) + q_fx_node = insert_quantized_node( + gm, + inp_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + None, + torch.per_tensor_affine, + ) + node.replace_all_uses_with(dq_fx_node) + last_quantized_node = dq_fx_node + elif namespace == "aten" and opname == "dequantize": + assert last_quantized_node is not None + node.replace_all_uses_with(last_quantized_node) + else: + last_quantized_node = node + + # Post-processing again to remove legacy ScriptObjects and quantizated tensors + # stored as attributes or in the buffer. This is used to clean up the GraphModule + # to not trigger tracing errors like missing __obj_flatten__ functions. 
+ def _clean_attr(mod: torch.nn.Module): + for submod in mod.modules(): + attr_names_to_clean = set() + for k, v in submod.__dict__.items(): + if isinstance(v, torch.ScriptObject): + attr_names_to_clean.add(k) + if k == "_buffers": + buffer_name_to_clean = set() + # pyrefly: ignore [missing-attribute] + for b_name, b_value in v.items(): + if isinstance(b_value, torch.Tensor) and b_value.dtype in [ + torch.qint8, + torch.quint8, + ]: + buffer_name_to_clean.add(b_name) + for b_name in buffer_name_to_clean: + # pyrefly: ignore [missing-attribute] + v.pop(b_name, None) + for attr_name in attr_names_to_clean: + delattr(submod, attr_name) + + if quantized: + """ + TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily + bypass test cases. + + The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing + will throw errors. However, the current way of SetAttr does inplace update to attributes, so + this pass regard them as dead code and remove them. Below is an example of GraphModule before + and after the dead code elimination pass. 
+ + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data = self.data; data = None + data_1 = self.data + add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None + data_2 = self.data + copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + """ + gm.graph.eliminate_dead_code() + _clean_attr(gm) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..5a15a5950575527b9beca532e4b0229b2603c1a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled + +from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +if TYPE_CHECKING: + from torch.export.graph_signature import ExportGraphSignature + + +def _is_set_grad_enabled_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target is torch._C._set_grad_enabled + ) + + 
+def _is_set_grad_enabled_sub_mod( + node: torch.fx.Node, omit_if_same_with_ambient: bool = False +) -> bool | torch.Tensor: + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target is torch._C._set_grad_enabled + ): + return ( + first_non_ph.args[0] != torch.is_grad_enabled() + if omit_if_same_with_ambient + else True + ) + return False + + +def _replace_with_hop(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node) + if len(set_grad_nodes) > 0: + assert len(set_grad_nodes) == 1 + set_grad_node = set_grad_nodes[0] + _replace_with_hop_helper(node, set_grad_node, wrap_with_set_grad_enabled) + sub_graph.erase_node(set_grad_node) + + +def _remove_set_grad_and_inline(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + nodes_map( + sub_graph.nodes, + lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n, + ) + node_inline_(node) + + +def _sequential_split_and_maybe_inline_subgraphs( + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Helper function for replace_set_grad_with_hop_pass(). 
+ Split the graph module into multiple subgraphs based on the set_grad_enabled nodes. + For each subgraph, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module. + """ + need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes) + if not need_replacing: + return gm, graph_signature + + # sequential_split returns a new graph module that could have different output + # args names. We need to fix the graph signature. + new_gm = sequential_split(gm, _is_set_grad_enabled_node) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): + if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True): + _replace_with_hop(node) + else: + _remove_set_grad_and_inline(node) + + return _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm, graph_signature, _maybe_inline_or_replace_with_hop + ) + + +def replace_set_grad_with_hop_pass( + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. 
def is_view_op(schema: torch._C.FunctionSchema) -> bool:
    """Tell whether ``schema`` describes a view op.

    An op counts as a view op when its first argument carries alias
    information that is not a write. Schemas without arguments, or whose
    first argument has no alias information, are not view ops.
    """
    schema_args = schema.arguments
    if not schema_args:
        return False

    first_alias = schema_args[0].alias_info
    if first_alias is None:
        return False
    return not first_alias.is_write
class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Export pass that swaps view operators for their view-copy counterparts.

    Our backend expects pure functional operators. For efficiency
    purposes, view ops are kept around while functionalizing the exported
    program; this pass replaces them with view copy ops for backends that
    need AOT memory planning.
    """

    def call_operator(self, op, args, kwargs, meta):
        """Forward ``op`` to the base class after picking its replacement."""
        if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
            # Explicit table entries win over schema inspection.
            replacement = _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]
        elif isinstance(op, HigherOrderOperator):
            # Higher-order operators are forwarded untouched; their schema
            # is never consulted.
            replacement = op
        else:
            # Falls back to the original op when no view-copy variant applies.
            replacement = get_view_copy_of_view_op(op._schema) or op
        return super().call_operator(replacement, args, kwargs, meta)
not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + + def set_hoo_node_meta(call_func_node): + call_func_node.meta["nn_module_stack"] = copy.copy( + enter_block_node.meta.get("nn_module_stack", {}) + ) + call_func_node.meta["torch_fn"] = ( + f"{wrap_hoo.__name__}", + # pyrefly: ignore [missing-attribute] + f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}", + ) + if isinstance(output_args, (tuple, list)): + call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args) + elif isinstance(output_args, torch.fx.Node): + call_func_node.meta["val"] = (output_args.meta["val"],) + + with graph.inserting_before(node): + get_attr_node = graph.get_attr(node.target) + get_attr_node.meta["nn_module_stack"] = copy.copy( + enter_block_node.meta.get("nn_module_stack", {}) + ) + output_node = next(iter(reversed(sub_gm.graph.nodes)), None) + # Split_module pass intentionally doesn't add output node + # if the graph doesn't return anything. 
+ # TODO (tmanlaibaatar) Figure out if this is right behaviour + # for split_module + if isinstance(output_node, torch.fx.Node) and output_node.op != "output": + output_node = None + if output_node is not None: + assert len(output_node.args) == 1 + output_args = output_node.args[0] + enter_block_node_args = enter_block_node.args + if isinstance(output_args, (tuple, list)): + call_func_node = graph.call_function( + wrap_hoo, + (*enter_block_node_args, get_attr_node, *node.args), + {}, + ) + # Create the metadata + set_hoo_node_meta(call_func_node) + node_replace_(node, call_func_node) + + # Rename the name of getitem nodes to the actual name of its contents + # for passing verifier and better readability, also propagate metadata + for get_item_node in call_func_node.users: + idx: int = get_item_node.args[1] # type: ignore[assignment] + output_node = output_args[idx] + get_item_node._rename(output_node.name) + get_item_node.meta = output_node.meta + + elif isinstance(output_args, torch.fx.Node): + call_func_node = graph.create_node( + "call_function", + wrap_hoo, + (*enter_block_node_args, get_attr_node, *node.args), + {}, + output_args.name, + ) + # Modify the subgraph to output a singleton list. + output_node.args = ((output_args,),) + # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph. 
+ get_item_node = graph.create_node( + "call_function", + operator.getitem, + (call_func_node, 0), + {}, + ) + # Create the metadata + get_item_node.meta = output_args.meta + set_hoo_node_meta(call_func_node) + node_replace_(node, get_item_node) + else: + raise NotImplementedError( + f"replace_with_hop_pass doesn't support output type {type(output_args)}" + ) + else: + # TODO (shangdiy): remove this line, since the export graph can be non-functional + node.graph.erase_node(node) + + +def _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature | None, + maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None], +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Helper function for replacing graph nodse with higher order nodes. + For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`. + """ + # new_gm is a new graph module that could have different output args names. + # We need to fix the graph signature. + replace_ctx = contextlib.nullcontext() + new_signature = None + if graph_signature is not None: + # Cannot deep copy a real ScriptObject, which is referenced + # in the FakeScriptObject. Copy should be good enough to guard + # against accidental mutation to original graph_signature. 
+ new_signature = copy.copy(graph_signature) + new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output"))) + assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len( + new_signature.output_specs + ) + for arg_node, out_spec in zip( + new_gm_out_node.args[0], new_signature.output_specs + ): + if arg_node is None: + assert out_spec.arg.value is None # type: ignore[union-attr] + elif ( + isinstance(arg_node, torch.fx.Node) + and out_spec.arg.name != arg_node.name + ): + out_spec.arg.name = arg_node.name + + replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook()) # type: ignore[assignment] + + with replace_ctx: + nodes_map( + list(new_gm.graph.nodes), + lambda node: ( + maybe_inline_or_replace_with_hop(node) + if node.op == "call_module" + else node + ), + ) + new_gm.recompile() + new_gm.graph.lint() + return new_gm, new_signature + + +def _replace_with_hop_pass_helper( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature | None, + sequential_split_and_maybe_inline_subgraphs: Callable[ + [torch.fx.GraphModule, ExportGraphSignature | None], + tuple[torch.fx.GraphModule, ExportGraphSignature | None], + ], +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. 
@dataclasses.dataclass
class RootDim:
    """
    Serializable record for a root Dim object: its value range plus the
    names of the dims derived from it.
    """

    # Lower bound of the dim's range.
    min: int
    # Upper bound of the dim's range; None when no finite bound is recorded.
    max: Optional[int]
    # Names of the derived dims built on top of this root.
    derived: list[str]
def _postprocess_serialized_shapes(
    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
    dims: dict[str, dict[str, Union[int, list[str], None]]],
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, dict[str, Any]]:
    """
    Build the serializable spec from tracked dims, in sorted order.

    Each entry of ``dims`` becomes a ``RootDim``: entries are ordered by dim
    name, the derived names are sorted, and a ``max`` of ``int_oo`` is
    stored as ``None``. The result is returned as a ``DynamicShapesSpec``,
    or as its dictionary form when ``to_dict`` is truthy.
    """
    from torch.utils._sympy.numbers import int_oo

    root_dims = {}
    for name in sorted(dims):
        info = dims[name]
        upper = info["max"]
        root_dims[name] = RootDim(
            min=info["min"],  # type: ignore[arg-type]
            max=None if upper is int_oo else upper,  # type: ignore[arg-type]
            derived=sorted(info["derived"]),  # type: ignore[arg-type]
        )

    # pyrefly: ignore [bad-argument-type]
    spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=root_dims)
    return _dataclass_to_dict(spec) if to_dict else spec
+ + For example: + ``` + dx = Dim("dx", min=4, max=16) + dy = dx + 1 + + inputs = ( + [ + torch.randn(4, 4), + torch.randn(5, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (Dim.STATIC,), + "c": None, + "d": None, + } + out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True) + ``` + would generate the following output: + ``` + { + "dynamic_shapes": ( + [ + ["dx", 4], + ["dx + 1", 4], + ], + ["_DimHint.STATIC"], + ["_DimHint.STATIC", "_DimHint.STATIC"], + None, + ), + "dims": { + "dx": { + "min": 4, + "max": 16, + "derived": ["dx + 1"], + }, + }, + } + ``` + """ + dims: dict[str, dict[str, Any]] = {} + + def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] + """ + Helps standardize the dynamic_shapes tree structure we serialize, + returning lists for each tensor shape, handling tensor-level Nones. + """ + if not isinstance(tensor, torch.Tensor): + return None + if shape is None: + return [Dim.STATIC] * len(tensor.shape) + + out = [] + if isinstance(shape, dict): + for i, s in enumerate(tensor.shape): + out.append(s if shape.get(i) is None else shape.get(i)) + else: + assert isinstance(shape, (tuple, list)) + for i, s in enumerate(tensor.shape): + out.append(s if shape[i] is None else shape[i]) + return out + + def _track_dim_from_dims( + val: Union[None, int, _DimHint, Dim], + ) -> Union[None, int, str]: + """ + Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. + """ + if val is None or isinstance(val, int): # non-tensor input or static + return val + if isinstance(val, _DimHint): # store enum as string + return val.__class__.__name__ + "." 
+ val.type.name + + assert isinstance(val, Dim) + + # track root dim + root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined] + if root.__name__ not in dims: + dims[root.__name__] = { + "min": root.min, # type: ignore[attr-defined,union-attr] + "max": root.max, # type: ignore[attr-defined,union-attr] + "derived": set(), + } + + # track derived dims + if isinstance(val, _DerivedDim): + dims[root.__name__]["derived"].add(val.__name__) + + return val.__name__ + + if dynamic_shapes is None: + return {"dynamic_shapes": None, "dims": {}} + + # convert to tuple of specs, for each arg/kwarg + kwargs = kwargs or {} + if isinstance(dynamic_shapes, dict): + dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment] + # pyrefly: ignore [bad-assignment, bad-argument-type] + dynamic_shapes = tuple(dynamic_shapes) + combined_args = tuple(args) + tuple(kwargs.values()) + + # run same check when we're processing shapes for export - is this too lazy? + _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type] + + tree_shapes = _tree_map_with_path( + _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs" + ) + serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes) + return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict) + + +def _load_dynamic_shapes( + spec: Union[DynamicShapesSpec, dict[str, Any]], + from_dict: Optional[bool] = False, +) -> Union[dict[str, Any], tuple[Any], list[Any], None]: + """ + Utility function for dynamic shapes serialization. + Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). 
+ """ + import sympy + + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + if from_dict: + if not isinstance(spec, dict): + raise UserError( + UserErrorType.INVALID_INPUT, + f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}", + ) + if sorted(spec.keys()) != ["dims", "dynamic_shapes"]: + raise UserError( + UserErrorType.INVALID_INPUT, + "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, " + f"instead found {spec.keys()}", + ) + dims = {} + for k, v in spec["dims"].items(): + if not isinstance(k, str): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}", + ) + if sorted(v.keys()) != ["derived", "max", "min"]: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, " + f"instead found {v.keys()}", + ) + if not isinstance(v["min"], int): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}", + ) + if not isinstance(v["max"], int) or v["max"] is None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}", + ) + if not isinstance(v["derived"], list) or any( + not isinstance(d, str) for d in v["derived"] + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, " + f"got {k}: {v['derived']}", + ) + dims[k] = RootDim(**v) + dynamic_shapes = spec["dynamic_shapes"] + else: + if not isinstance(spec, DynamicShapesSpec): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}", + ) + dims = spec.dims + dynamic_shapes = spec.dynamic_shapes + + if dynamic_shapes is None: + return None + + dim_cache = {} + for name, info in dims.items(): + symbol = 
sympy.sympify(name) + if not isinstance(symbol, sympy.Symbol): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be symbols, got {name}", + ) + dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim + for _expr in info.derived: + expr = sympy.sympify(_expr) + if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions in to have {name} as the only free symbol, got {expr}", + ) + if not _is_supported_equivalence(expr): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions to be linear expressions, got {expr}", + ) + modulus, remainder = sympy.polys.polytools.div(expr, symbol) + ddim = dim_cache[name] + if modulus != 1: + ddim = int(modulus) * ddim # type: ignore[assignment, operator] + if remainder != 0: + ddim = ddim + int(remainder) # type: ignore[assignment, operator] + dim_cache[_expr] = ddim # cache derived dims + + def deserialize_shape( + val: Union[None, int, str], + ) -> Union[None, int, Dim, _DimHint]: + if val is None or isinstance(val, int): + return val + elif val == "_DimHint.AUTO": + return _DimHint.AUTO() + elif val == "_DimHint.DYNAMIC": + return _DimHint.DYNAMIC() + elif val == "_DimHint.STATIC": + return _DimHint.STATIC() + if not isinstance(val, str): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, " + f" or derived expressions, got {val}", + ) + if val not in dim_cache: + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + f"got {val} which is not in {dims.keys()}", + ) + return dim_cache[val] # type: ignore[return-value] + + return tree_map(deserialize_shape, dynamic_shapes) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift new file mode 100644 index 0000000000000000000000000000000000000000..155f52595740c5a1d57b8071a11b509ef16d5fce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift @@ -0,0 +1,377 @@ +// @generated by update_schema.py +// checksum<<0e870e558fb4362f69b825842ab606cf0becd10a008003ac676156becf20b65b>> + +namespace py3 torch._export +namespace cpp2 torch._export.schema + +enum ArgumentKind { + UNKNOWN = 0, + POSITIONAL = 1, + KEYWORD = 2, +} + + +enum Layout { + Unknown = 0, + SparseCoo = 1, + SparseCsr = 2, + SparseCsc = 3, + SparseBsr = 4, + SparseBsc = 5, + _mkldnn = 6, + Strided = 7, +} + + +enum MemoryFormat { + Unknown = 0, + ContiguousFormat = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + PreserveFormat = 4, +} + + +enum ScalarType { + UNKNOWN = 0, + BYTE = 1, + CHAR = 2, + SHORT = 3, + INT = 4, + LONG = 5, + HALF = 6, + FLOAT = 7, + DOUBLE = 8, + COMPLEXHALF = 9, + COMPLEXFLOAT = 10, + COMPLEXDOUBLE = 11, + BOOL = 12, + BFLOAT16 = 13, + UINT16 = 28, + FLOAT8E4M3FN = 29, + FLOAT8E5M2 = 30, + FLOAT8E4M3FNUZ = 31, + FLOAT8E5M2FNUZ = 32, +} + + +struct Device { + 10: string type; + 20: optional i64 index; +} + +union SymExprHint { + 10: i64 as_int; + 20: bool as_bool; + 30: double as_float; +} + +struct SymExpr { + 10: string expr_str; + 20: optional SymExprHint hint; +} + +union SymInt { + 10: SymExpr as_expr; + 20: i64 as_int; +} + +union SymFloat { + 10: SymExpr as_expr; + 20: double as_float; +} + +union SymBool { + 10: SymExpr as_expr; + 20: bool as_bool; +} + +struct TensorMeta { + 10: ScalarType dtype; + 20: list sizes; + 30: bool requires_grad; + 40: Device device; + 50: list strides; + 60: SymInt storage_offset; + 70: Layout layout; +} + +union SymIntArgument { + 10: string as_name; + 20: i64 as_int; +} + +union SymFloatArgument { + 10: string as_name; + 20: double as_float; +} + +union 
SymBoolArgument { + 10: string as_name; + 20: bool as_bool; +} + +struct TensorArgument { + 10: string name; +} + +struct TokenArgument { + 10: string name; +} + +union OptionalTensorArgument { + 20: TensorArgument as_tensor; + 10: bool as_none; +} + +struct GraphArgument { + 10: string name; + 20: Graph graph; +} + +struct CustomObjArgument { + 10: string name; + 20: string class_fqn; +} + +struct ComplexValue { + 10: double real; + 20: double imag; +} + +union Argument { + 10: bool as_none; + 20: TensorArgument as_tensor; + 30: list as_tensors; + 50: i64 as_int; + 70: list as_ints; + 80: double as_float; + 90: list as_floats; + 100: string as_string; + 101: list as_strings; + 110: SymIntArgument as_sym_int; + 120: list as_sym_ints; + 130: ScalarType as_scalar_type; + 140: MemoryFormat as_memory_format; + 150: Layout as_layout; + 160: Device as_device; + 170: bool as_bool; + 180: list as_bools; + 182: SymBoolArgument as_sym_bool; + 184: list as_sym_bools; + 200: GraphArgument as_graph; + 190: list as_optional_tensors; + 210: CustomObjArgument as_custom_obj; + 220: string as_operator; + 230: SymFloatArgument as_sym_float; + 240: list as_sym_floats; + 250: OptionalTensorArgument as_optional_tensor; + 260: ComplexValue as_complex; + 280: list> as_int_lists; + 290: map as_string_to_argument; +} + +struct NamedArgument { + 10: string name; + 20: Argument arg; + 30: optional ArgumentKind kind; +} + +struct Node { + 10: string target; + 20: list inputs; + 30: list outputs; + 40: map metadata; + 50: optional bool is_hop_single_tensor_return; +} + +struct Graph { + 10: list inputs; + 20: list outputs; + 30: list nodes; + 40: map tensor_values; + 50: map sym_int_values; + 60: map sym_bool_values; + 70: bool is_single_tensor_return; + 80: map custom_obj_values; + 90: map sym_float_values; +} + +struct UserInputSpec { + 10: Argument arg; +} + +union ConstantValue { + 10: bool as_none; + 20: i64 as_int; + 30: double as_float; + 40: string as_string; + 50: bool as_bool; +} + 
+struct InputToConstantInputSpec { + 10: string name; + 20: ConstantValue value; +} + +struct InputToParameterSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct InputToBufferSpec { + 10: TensorArgument arg; + 20: string buffer_name; + 30: bool persistent; +} + +struct InputToTensorConstantSpec { + 10: TensorArgument arg; + 20: string tensor_constant_name; +} + +struct InputToCustomObjSpec { + 10: CustomObjArgument arg; + 20: string custom_obj_name; +} + +struct InputTokenSpec { + 10: TokenArgument arg; +} + +union InputSpec { + 10: UserInputSpec user_input; + 20: InputToParameterSpec parameter; + 30: InputToBufferSpec buffer; + 40: InputToTensorConstantSpec tensor_constant; + 50: InputToCustomObjSpec custom_obj; + 70: InputTokenSpec token; + 60: InputToConstantInputSpec constant_input; +} + +struct UserOutputSpec { + 10: Argument arg; +} + +struct LossOutputSpec { + 10: TensorArgument arg; +} + +struct BufferMutationSpec { + 10: TensorArgument arg; + 20: string buffer_name; +} + +struct ParameterMutationSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct GradientToParameterSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct GradientToUserInputSpec { + 10: TensorArgument arg; + 20: string user_input_name; +} + +struct UserInputMutationSpec { + 10: TensorArgument arg; + 20: string user_input_name; +} + +struct OutputTokenSpec { + 10: TokenArgument arg; +} + +union OutputSpec { + 10: UserOutputSpec user_output; + 20: LossOutputSpec loss_output; + 30: BufferMutationSpec buffer_mutation; + 40: GradientToParameterSpec gradient_to_parameter; + 50: GradientToUserInputSpec gradient_to_user_input; + 60: UserInputMutationSpec user_input_mutation; + 70: OutputTokenSpec token; + 80: ParameterMutationSpec parameter_mutation; +} + +struct GraphSignature { + 10: list input_specs; + 20: list output_specs; +} + +struct RangeConstraint { + 10: optional i64 min_val; + 20: optional i64 max_val; +} + +struct 
ModuleCallSignature { + 10: list inputs; + 20: list outputs; + 30: string in_spec; + 40: string out_spec; + 50: optional list forward_arg_names; +} + +struct ModuleCallEntry { + 10: string fqn; + 30: optional ModuleCallSignature signature; +} + +struct NamedTupleDef { + 10: list field_names; +} + +struct GraphModule { + 10: Graph graph; + 50: GraphSignature signature; + 60: list module_call_graph; + 40: map metadata; + 70: map treespec_namedtuple_fields; +} + +struct SchemaVersion { + 10: i64 major; + 20: i64 minor; +} + +struct ExportedProgram { + 10: GraphModule graph_module; + 20: map opset_version; + 30: map range_constraints; + 60: SchemaVersion schema_version; + 70: list verifiers; + 80: string torch_version; + 90: list guards_code; +} + +struct PayloadMeta { + 10: string path_name; + 20: bool is_param; + 30: bool use_pickle; + 40: optional TensorMeta tensor_meta; +} + +struct PayloadConfig { + 10: map config; +} + +struct AOTInductorModelPickleData { + 1: string library_basename; + 2: list input_names; + 3: list output_names; + 4: optional i64 floating_point_input_dtype; + 5: optional i64 floating_point_output_dtype; + 6: optional bool aot_inductor_model_is_cpu; +} + +struct ExternKernelNode { + 10: string name; + 20: Node node; +} + +struct ExternKernelNodes { + 10: list nodes; +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..0d95ca32e6455ad2e8b13e1274a39a9ae0e78fd5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema.py @@ -0,0 +1,520 @@ +# NOTE: This is a placeholder for iterating on export serialization schema design. +# Anything is subject to change and no guarantee is provided at this point. 
+ +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Annotated, Optional + +from torch._export.serde.union import _Union, _union_dataclass + + +# NOTE: Please update this value if any modifications are made to the schema +SCHEMA_VERSION = (8, 15) +TREESPEC_VERSION = 1 + + +# NOTE: If you updated the schema, please run `scripts/export/update_schema.py` +# to update the auto generated files. +class ScalarType(IntEnum): + UNKNOWN = 0 + BYTE = 1 + CHAR = 2 + SHORT = 3 + INT = 4 + LONG = 5 + HALF = 6 + FLOAT = 7 + DOUBLE = 8 + COMPLEXHALF = 9 + COMPLEXFLOAT = 10 + COMPLEXDOUBLE = 11 + BOOL = 12 + BFLOAT16 = 13 + UINT16 = 28 + FLOAT8E4M3FN = 29 + FLOAT8E5M2 = 30 + FLOAT8E4M3FNUZ = 31 + FLOAT8E5M2FNUZ = 32 + + +class Layout(IntEnum): + Unknown = 0 + SparseCoo = 1 + SparseCsr = 2 + SparseCsc = 3 + SparseBsr = 4 + SparseBsc = 5 + _mkldnn = 6 + Strided = 7 + + +class MemoryFormat(IntEnum): + Unknown = 0 + ContiguousFormat = 1 + ChannelsLast = 2 + ChannelsLast3d = 3 + PreserveFormat = 4 + + +@dataclass +class Device: + type: Annotated[str, 10] + index: Annotated[Optional[int], 20] = None + + +@_union_dataclass +class SymExprHint(_Union): + as_int: Annotated[int, 10] + as_bool: Annotated[bool, 20] + as_float: Annotated[float, 30] + + +# This is for storing the symbolic expressions behind symints/symfloats/symbools +# For example, we can get something like +# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4) +# if we also have the hint that s0 and s1 are both 2. 
+@dataclass +class SymExpr: + expr_str: Annotated[str, 10] + hint: Annotated[Optional[SymExprHint], 20] = None + + +@_union_dataclass +class SymInt(_Union): + as_expr: Annotated[SymExpr, 10] + as_int: Annotated[int, 20] + + +@_union_dataclass +class SymFloat(_Union): + as_expr: Annotated[SymExpr, 10] + as_float: Annotated[float, 20] + + +@_union_dataclass +class SymBool(_Union): + as_expr: Annotated[SymExpr, 10] + as_bool: Annotated[bool, 20] + + +@dataclass +class TensorMeta: + dtype: Annotated[ScalarType, 10] + sizes: Annotated[list[SymInt], 20] + requires_grad: Annotated[bool, 30] + device: Annotated[Device, 40] + strides: Annotated[list[SymInt], 50] + storage_offset: Annotated[SymInt, 60] + layout: Annotated[Layout, 70] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymInts. +# The "as_int" field is used in the case where we have a list containing a mix +# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to +# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints +# to the "as_int" field. +@_union_dataclass +class SymIntArgument(_Union): + as_name: Annotated[str, 10] + as_int: Annotated[int, 20] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymFloats. +# The "as_float" field is used in the case where we have a list containing a mix +# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to +# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints +# to the "as_float" field. +@_union_dataclass +class SymFloatArgument(_Union): + as_name: Annotated[str, 10] + as_float: Annotated[float, 20] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymBools. +# The "as_bool" field is used in the case where we have a list containing a mix +# of SymBool and bools (ex. [True, i0, ...]). 
We will serialize this type of list to +# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools +# to the "as_bool" field. +@_union_dataclass +class SymBoolArgument(_Union): + as_name: Annotated[str, 10] + as_bool: Annotated[bool, 20] + + +@dataclass +class TensorArgument: + name: Annotated[str, 10] + + +@dataclass +class TokenArgument: + name: Annotated[str, 10] + + +# This is use for storing the contents of a list which contain optional tensors +# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the +# type List[OptionalTensorArgument], with tensor values serialized to the +# "as_tensor" field, and None values serialized to the "as_none" field. +@_union_dataclass +class OptionalTensorArgument(_Union): + as_tensor: Annotated[TensorArgument, 20] + as_none: Annotated[bool, 10] + + +@dataclass +class GraphArgument: + name: Annotated[str, 10] + graph: Annotated["Graph", 20] + + +@dataclass +class CustomObjArgument: + name: Annotated[str, 10] + class_fqn: Annotated[str, 20] + + +@dataclass +class ComplexValue: + real: Annotated[float, 10] + imag: Annotated[float, 20] + + +# This is actually a union type +@_union_dataclass +class Argument(_Union): + as_none: Annotated[bool, 10] + as_tensor: Annotated[TensorArgument, 20] + as_tensors: Annotated[list[TensorArgument], 30] + as_int: Annotated[int, 50] + as_ints: Annotated[list[int], 70] + as_float: Annotated[float, 80] + as_floats: Annotated[list[float], 90] + as_string: Annotated[str, 100] + as_strings: Annotated[list[str], 101] + as_sym_int: Annotated[SymIntArgument, 110] + as_sym_ints: Annotated[list[SymIntArgument], 120] + as_scalar_type: Annotated[ScalarType, 130] + as_memory_format: Annotated[MemoryFormat, 140] + as_layout: Annotated[Layout, 150] + as_device: Annotated[Device, 160] + as_bool: Annotated[bool, 170] + as_bools: Annotated[list[bool], 180] + as_sym_bool: Annotated[SymBoolArgument, 182] + as_sym_bools: Annotated[list[SymBoolArgument], 184] + as_graph: 
Annotated[GraphArgument, 200] + as_optional_tensors: Annotated[list[OptionalTensorArgument], 190] + as_custom_obj: Annotated[CustomObjArgument, 210] + as_operator: Annotated[str, 220] + as_sym_float: Annotated[SymFloatArgument, 230] + as_sym_floats: Annotated[list[SymFloatArgument], 240] + as_optional_tensor: Annotated[OptionalTensorArgument, 250] + as_complex: Annotated[ComplexValue, 260] + as_int_lists: Annotated[list[list[int]], 280] + as_string_to_argument: Annotated[dict[str, "Argument"], 290] + + +class ArgumentKind(IntEnum): + UNKNOWN = 0 + POSITIONAL = 1 + KEYWORD = 2 + + +@dataclass +class NamedArgument: + # Argument name from the operator schema + name: Annotated[str, 10] + arg: Annotated[Argument, 20] + kind: Annotated[Optional[ArgumentKind], 30] = None + + +@dataclass +class Node: + target: Annotated[str, 10] + inputs: Annotated[list[NamedArgument], 20] + outputs: Annotated[list[Argument], 30] + metadata: Annotated[dict[str, str], 40] + is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None + + +@dataclass +class Graph: + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + nodes: Annotated[list[Node], 30] + tensor_values: Annotated[dict[str, TensorMeta], 40] + sym_int_values: Annotated[dict[str, SymInt], 50] + sym_bool_values: Annotated[dict[str, SymBool], 60] + # This is for deserializing the submodule graphs from higher order ops + # (ex. cond, map) where single tensor returns will just return a single + # tensor, rather than following export schema and returning a singleton + # list. 
+ is_single_tensor_return: Annotated[bool, 70] = False + custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field( + default_factory=dict + ) + sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict) + + +@dataclass +class UserInputSpec: + # Actually, only tensors and SymInts are allowed here + arg: Annotated[Argument, 10] + + +@_union_dataclass +class ConstantValue(_Union): + as_none: Annotated[bool, 10] + as_int: Annotated[int, 20] + as_float: Annotated[float, 30] + as_string: Annotated[str, 40] + as_bool: Annotated[bool, 50] + + +@dataclass +class InputToConstantInputSpec: + name: Annotated[str, 10] + value: Annotated[ConstantValue, 20] + + +@dataclass +class InputToParameterSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class InputToBufferSpec: + arg: Annotated[TensorArgument, 10] + buffer_name: Annotated[str, 20] + persistent: Annotated[bool, 30] + + +@dataclass +class InputToTensorConstantSpec: + arg: Annotated[TensorArgument, 10] + tensor_constant_name: Annotated[str, 20] + + +@dataclass +class InputToCustomObjSpec: + arg: Annotated[CustomObjArgument, 10] + custom_obj_name: Annotated[str, 20] + + +@dataclass +class InputTokenSpec: + arg: Annotated[TokenArgument, 10] + + +@_union_dataclass +class InputSpec(_Union): + user_input: Annotated[UserInputSpec, 10] + parameter: Annotated[InputToParameterSpec, 20] + buffer: Annotated[InputToBufferSpec, 30] + tensor_constant: Annotated[InputToTensorConstantSpec, 40] + custom_obj: Annotated[InputToCustomObjSpec, 50] + token: Annotated[InputTokenSpec, 70] + constant_input: Annotated[InputToConstantInputSpec, 60] + + +@dataclass +class UserOutputSpec: + arg: Annotated[Argument, 10] + + +@dataclass +class LossOutputSpec: + arg: Annotated[TensorArgument, 10] + + +@dataclass +class BufferMutationSpec: + arg: Annotated[TensorArgument, 10] + buffer_name: Annotated[str, 20] + + +@dataclass +class ParameterMutationSpec: + arg: 
Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class GradientToParameterSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class GradientToUserInputSpec: + arg: Annotated[TensorArgument, 10] + user_input_name: Annotated[str, 20] + + +@dataclass +class UserInputMutationSpec: + arg: Annotated[TensorArgument, 10] + user_input_name: Annotated[str, 20] + + +@dataclass +class OutputTokenSpec: + arg: Annotated[TokenArgument, 10] + + +@_union_dataclass +class OutputSpec(_Union): + user_output: Annotated[UserOutputSpec, 10] + loss_output: Annotated[LossOutputSpec, 20] + buffer_mutation: Annotated[BufferMutationSpec, 30] + gradient_to_parameter: Annotated[GradientToParameterSpec, 40] + gradient_to_user_input: Annotated[GradientToUserInputSpec, 50] + user_input_mutation: Annotated[UserInputMutationSpec, 60] + token: Annotated[OutputTokenSpec, 70] + parameter_mutation: Annotated[ParameterMutationSpec, 80] + + +@dataclass +class GraphSignature: + input_specs: Annotated[list[InputSpec], 10] + output_specs: Annotated[list[OutputSpec], 20] + + +@dataclass +class RangeConstraint: + min_val: Annotated[Optional[int], 10] + max_val: Annotated[Optional[int], 20] + + +@dataclass +class ModuleCallSignature: + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + + # These are serialized by calling pytree.treespec_loads + # And deserialized by calling pytree.treespec_dumps + in_spec: Annotated[str, 30] + out_spec: Annotated[str, 40] + + # This field is used to prettify the graph placeholders + # after we Ser/Der and retrace + forward_arg_names: Annotated[Optional[list[str]], 50] = None + + +@dataclass +class ModuleCallEntry: + fqn: Annotated[str, 10] + signature: Annotated[Optional[ModuleCallSignature], 30] = None + + +@dataclass +class NamedTupleDef: + field_names: Annotated[list[str], 10] + + +@dataclass +class GraphModule: + graph: Annotated[Graph, 10] + signature: 
Annotated[GraphSignature, 50] + # This is used for unflattening, by tracking the calling structure of all of + # the modules in order to unflatten the modules back to the eager calling + # conventions. + module_call_graph: Annotated[list[ModuleCallEntry], 60] + metadata: Annotated[dict[str, str], 40] = field(default_factory=dict) + # Mapping of namedtuple types to namedtuple field names, used for BC + treespec_namedtuple_fields: Annotated[dict[str, NamedTupleDef], 70] = field( + default_factory=dict + ) + + +# Invariant: Every time a change is made to the schema, one of the versions +# should be updated. +@dataclass +class SchemaVersion: + major: Annotated[ + int, 10 + ] # Major version number is bumped every time a breaking change is made. + minor: Annotated[ + int, 20 + ] # Minor version number is bumped when a compatible change is made. + + +@dataclass +class ExportedProgram: + graph_module: Annotated[GraphModule, 10] + # Key is the opset namespace (ex. aten), and value is the version number + opset_version: Annotated[dict[str, int], 20] + range_constraints: Annotated[dict[str, RangeConstraint], 30] + schema_version: Annotated[SchemaVersion, 60] + verifiers: Annotated[list[str], 70] = field(default_factory=list) + torch_version: Annotated[str, 80] = "<=2.4" + guards_code: Annotated[list[str], 90] = field(default_factory=list) + + +######################################################################### +# Container types for inference tasks, not being used directly for export. +######################################################################### + + +# The metadata for payload saved in PT2 archive. +# payload includes params, buffers, tensor constants, and custom objects. +@dataclass +class PayloadMeta: + # the path of the payload in the archive file, e.g. "weight_0" + path_name: Annotated[str, 10] + is_param: Annotated[bool, 20] + # whether the payload is serialized using pickle. 
+ # Only custom objects and tensor subclasses that are not fake tensors + # are serialized using pickle. + use_pickle: Annotated[bool, 30] + # Custom Objects don't have tensor_meta and will be serialized using pickle + tensor_meta: Annotated[Optional[TensorMeta], 40] + + +# The mapping from payload FQN to its metadata. +@dataclass +class PayloadConfig: + config: Annotated[dict[str, PayloadMeta], 10] + + +# +# The structure is used to serialize instances of AOTInductorModel to pass +# them from the publishing pipeline to the predictor. +# +# All new fields should be marked as optional. +# +@dataclass +class AOTInductorModelPickleData: + # Base name of an associated .so AOTInductor library. Typically looks like: + # "abc.so". + library_basename: Annotated[str, 1] + + # AOTInductor engine input names. + input_names: Annotated[list[str], 2] + + # AOTInductor engine output names. + output_names: Annotated[list[str], 3] + + # These fields tell whether floating point inputs/outputs should be converted to + # a certain type. If None, the dtypes that the AOTInductor engine inferred from the sample + # inputs are used. + floating_point_input_dtype: Annotated[Optional[int], 4] = None + floating_point_output_dtype: Annotated[Optional[int], 5] = None + + # Whether AOTInductor runtime is for CPU. 
+ aot_inductor_model_is_cpu: Annotated[Optional[bool], 6] = None + + +@dataclass +class ExternKernelNode: + # name is not the unique identifier of the node + name: Annotated[str, 10] + node: Annotated[Node, 20] + + +@dataclass +class ExternKernelNodes: + nodes: Annotated[list[ExternKernelNode], 10] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema.yaml b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f13741416cb35c4a6ac482c9f95c8d87a61e9d7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema.yaml @@ -0,0 +1,559 @@ +# @generated by update_schema.py +# checksum<> +AOTInductorModelPickleData: + kind: struct + fields: + library_basename: + type: str + input_names: + type: List[str] + output_names: + type: List[str] + floating_point_input_dtype: + type: Optional[int] + default: None + floating_point_output_dtype: + type: Optional[int] + default: None + aot_inductor_model_is_cpu: + type: Optional[bool] + default: None +Argument: + kind: union + fields: + as_none: + type: bool + as_tensor: + type: TensorArgument + as_tensors: + type: List[TensorArgument] + as_int: + type: int + as_ints: + type: List[int] + as_float: + type: float + as_floats: + type: List[float] + as_string: + type: str + as_strings: + type: List[str] + as_sym_int: + type: SymIntArgument + as_sym_ints: + type: List[SymIntArgument] + as_scalar_type: + type: ScalarType + as_memory_format: + type: MemoryFormat + as_layout: + type: Layout + as_device: + type: Device + as_bool: + type: bool + as_bools: + type: List[bool] + as_sym_bool: + type: SymBoolArgument + as_sym_bools: + type: List[SymBoolArgument] + as_graph: + type: GraphArgument + as_optional_tensors: + type: List[OptionalTensorArgument] + as_custom_obj: + type: CustomObjArgument + as_operator: + type: str + as_sym_float: + type: 
SymFloatArgument + as_sym_floats: + type: List[SymFloatArgument] + as_optional_tensor: + type: OptionalTensorArgument + as_complex: + type: ComplexValue + as_int_lists: + type: List[List[int]] + as_string_to_argument: + type: Dict[str, Argument] +ArgumentKind: + kind: enum + fields: + UNKNOWN: 0 + POSITIONAL: 1 + KEYWORD: 2 +BufferMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str +ComplexValue: + kind: struct + fields: + real: + type: float + imag: + type: float +ConstantValue: + kind: union + fields: + as_none: + type: bool + as_int: + type: int + as_float: + type: float + as_string: + type: str + as_bool: + type: bool +CustomObjArgument: + kind: struct + fields: + name: + type: str + class_fqn: + type: str +Device: + kind: struct + fields: + type: + type: str + index: + type: Optional[int] + default: None +ExportedProgram: + kind: struct + fields: + graph_module: + type: GraphModule + opset_version: + type: Dict[str, int] + range_constraints: + type: Dict[str, RangeConstraint] + schema_version: + type: SchemaVersion + verifiers: + type: List[str] + default: '[]' + torch_version: + type: str + default: <=2.4 + guards_code: + type: List[str] + default: '[]' +ExternKernelNode: + kind: struct + fields: + name: + type: str + node: + type: Node +ExternKernelNodes: + kind: struct + fields: + nodes: + type: List[ExternKernelNode] +GradientToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +GradientToUserInputSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +Graph: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + nodes: + type: List[Node] + tensor_values: + type: Dict[str, TensorMeta] + sym_int_values: + type: Dict[str, SymInt] + sym_bool_values: + type: Dict[str, SymBool] + is_single_tensor_return: + type: bool + default: 'False' + custom_obj_values: + type: Dict[str, CustomObjArgument] + 
default: '{}' + sym_float_values: + type: Dict[str, SymFloat] + default: '{}' +GraphArgument: + kind: struct + fields: + name: + type: str + graph: + type: Graph +GraphModule: + kind: struct + fields: + graph: + type: Graph + signature: + type: GraphSignature + module_call_graph: + type: List[ModuleCallEntry] + metadata: + type: Dict[str, str] + default: '{}' + treespec_namedtuple_fields: + type: Dict[str, NamedTupleDef] + default: '{}' +GraphSignature: + kind: struct + fields: + input_specs: + type: List[InputSpec] + output_specs: + type: List[OutputSpec] +InputSpec: + kind: union + fields: + user_input: + type: UserInputSpec + parameter: + type: InputToParameterSpec + buffer: + type: InputToBufferSpec + tensor_constant: + type: InputToTensorConstantSpec + custom_obj: + type: InputToCustomObjSpec + token: + type: InputTokenSpec + constant_input: + type: InputToConstantInputSpec +InputToBufferSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str + persistent: + type: bool +InputToConstantInputSpec: + kind: struct + fields: + name: + type: str + value: + type: ConstantValue +InputToCustomObjSpec: + kind: struct + fields: + arg: + type: CustomObjArgument + custom_obj_name: + type: str +InputToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +InputToTensorConstantSpec: + kind: struct + fields: + arg: + type: TensorArgument + tensor_constant_name: + type: str +InputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +Layout: + kind: enum + fields: + Unknown: 0 + SparseCoo: 1 + SparseCsr: 2 + SparseCsc: 3 + SparseBsr: 4 + SparseBsc: 5 + _mkldnn: 6 + Strided: 7 +LossOutputSpec: + kind: struct + fields: + arg: + type: TensorArgument +MemoryFormat: + kind: enum + fields: + Unknown: 0 + ContiguousFormat: 1 + ChannelsLast: 2 + ChannelsLast3d: 3 + PreserveFormat: 4 +ModuleCallEntry: + kind: struct + fields: + fqn: + type: str + signature: + type: Optional[ModuleCallSignature] + 
default: None +ModuleCallSignature: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + in_spec: + type: str + out_spec: + type: str + forward_arg_names: + type: Optional[List[str]] + default: None +NamedArgument: + kind: struct + fields: + name: + type: str + arg: + type: Argument + kind: + type: Optional[ArgumentKind] + default: None +NamedTupleDef: + kind: struct + fields: + field_names: + type: List[str] +Node: + kind: struct + fields: + target: + type: str + inputs: + type: List[NamedArgument] + outputs: + type: List[Argument] + metadata: + type: Dict[str, str] + is_hop_single_tensor_return: + type: Optional[bool] + default: None +OptionalTensorArgument: + kind: union + fields: + as_tensor: + type: TensorArgument + as_none: + type: bool +OutputSpec: + kind: union + fields: + user_output: + type: UserOutputSpec + loss_output: + type: LossOutputSpec + buffer_mutation: + type: BufferMutationSpec + gradient_to_parameter: + type: GradientToParameterSpec + gradient_to_user_input: + type: GradientToUserInputSpec + user_input_mutation: + type: UserInputMutationSpec + token: + type: OutputTokenSpec + parameter_mutation: + type: ParameterMutationSpec +OutputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +ParameterMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +PayloadConfig: + kind: struct + fields: + config: + type: Dict[str, PayloadMeta] +PayloadMeta: + kind: struct + fields: + path_name: + type: str + is_param: + type: bool + use_pickle: + type: bool + tensor_meta: + type: Optional[TensorMeta] +RangeConstraint: + kind: struct + fields: + min_val: + type: Optional[int] + max_val: + type: Optional[int] +ScalarType: + kind: enum + fields: + UNKNOWN: 0 + BYTE: 1 + CHAR: 2 + SHORT: 3 + INT: 4 + LONG: 5 + HALF: 6 + FLOAT: 7 + DOUBLE: 8 + COMPLEXHALF: 9 + COMPLEXFLOAT: 10 + COMPLEXDOUBLE: 11 + BOOL: 12 + BFLOAT16: 13 + UINT16: 28 + FLOAT8E4M3FN: 29 + FLOAT8E5M2: 
30 + FLOAT8E4M3FNUZ: 31 + FLOAT8E5M2FNUZ: 32 +SchemaVersion: + kind: struct + fields: + major: + type: int + minor: + type: int +SymBool: + kind: union + fields: + as_expr: + type: SymExpr + as_bool: + type: bool +SymBoolArgument: + kind: union + fields: + as_name: + type: str + as_bool: + type: bool +SymExpr: + kind: struct + fields: + expr_str: + type: str + hint: + type: Optional[SymExprHint] + default: None +SymExprHint: + kind: union + fields: + as_int: + type: int + as_bool: + type: bool + as_float: + type: float +SymFloat: + kind: union + fields: + as_expr: + type: SymExpr + as_float: + type: float +SymFloatArgument: + kind: union + fields: + as_name: + type: str + as_float: + type: float +SymInt: + kind: union + fields: + as_expr: + type: SymExpr + as_int: + type: int +SymIntArgument: + kind: union + fields: + as_name: + type: str + as_int: + type: int +TensorArgument: + kind: struct + fields: + name: + type: str +TensorMeta: + kind: struct + fields: + dtype: + type: ScalarType + sizes: + type: List[SymInt] + requires_grad: + type: bool + device: + type: Device + strides: + type: List[SymInt] + storage_offset: + type: SymInt + layout: + type: Layout +TokenArgument: + kind: struct + fields: + name: + type: str +UserInputMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +UserInputSpec: + kind: struct + fields: + arg: + type: Argument +UserOutputSpec: + kind: struct + fields: + arg: + type: Argument +SCHEMA_VERSION: +- 8 +- 15 +TREESPEC_VERSION: 1 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema_check.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema_check.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec1fdb9026b9e2f2dec6d9f13ca0d6246904f3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/schema_check.py @@ -0,0 +1,741 @@ +# mypy: 
allow-untyped-defs +import dataclasses +import hashlib +import inspect +import re +import typing +from enum import IntEnum +from typing import Annotated, Any, ForwardRef, Optional, Union + +from torch._export.serde import schema +from torch._export.serde.union import _Union + + +class SchemaUpdateError(Exception): + pass + + +def _check(x, msg): + if not x: + raise SchemaUpdateError(msg) + + +_CPP_TYPE_MAP = { + str: "std::string", + int: "int64_t", + float: "F64", + bool: "bool", +} + +_THRIFT_TYPE_MAP = { + str: "string", + int: "i64", + float: "double", + bool: "bool", +} + + +def _staged_schema(): + yaml_ret: dict[str, Any] = {} + defs = {} + cpp_enum_defs: dict[str, str] = {} + cpp_class_defs: dict[str, str] = {} + cpp_type_decls: list[str] = [] + cpp_json_defs: list[str] = [] + thrift_enum_defs: list[str] = [] + thrift_type_defs: dict[str, str] = {} + + def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + def dump_type(t, level: int) -> tuple[str, str, str]: + if getattr(t, "__name__", None) in cpp_enum_defs: + return t.__name__, "int64_t", t.__name__ + elif t in _CPP_TYPE_MAP: + return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t]) + elif isinstance(t, str): + assert t in defs + assert t not in cpp_enum_defs + assert "[" not in t + return t, f"ForwardRef<{t}>", t + elif isinstance(t, ForwardRef): + return ( + t.__forward_arg__, + f"ForwardRef<{t.__forward_arg__}>", + t.__forward_arg__, + ) + elif o := typing.get_origin(t): + # Lemme know if there's a better way to do this. + if o is list: + yaml_head, cpp_head, thrift_head, thrift_tail = ( + "List", + "std::vector", + "list<", + ">", + ) + elif o is dict: + yaml_head, cpp_head, thrift_head, thrift_tail = ( + "Dict", + "std::unordered_map", + "map<", + ">", + ) + elif o == Union: + assert level == 0, "Optional is only supported at the top level." 
+ args = typing.get_args(t) + assert len(args) == 2 and args[1] is type(None) + yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1) + return ( + f"Optional[{yaml_type}]", + f"std::optional<{cpp_type}>", + f"optional {thrift_type}", + ) + elif o is Annotated: + return dump_type(t.__origin__, level) + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + yaml_arg_types, cpp_arg_types, thrift_arg_types = zip( + *[dump_type(x, level + 1) for x in typing.get_args(t)] + ) + return ( + (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), + (f"{cpp_head}<{', '.join(cpp_arg_types)}>"), + f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}", + ) + elif isinstance(t, type): + return (t.__name__, t.__name__, t.__name__) + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + + def dump_cpp_value(v) -> str: + if v is None: + return "std::nullopt" + elif v is True: + return "true" + elif v is False: + return "false" + elif v == {}: + return "{}" + elif v == []: + return "{}" + elif v == (): + return "{}" + elif isinstance(v, str): + return f'"{v}"' + else: + raise AssertionError( + f"Default value {v} is not supported yet in export schema." + ) + + def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: + t, cpp_type, thrift_type = dump_type(f.type, 0) + ret = {"type": t} + cpp_default: Optional[str] = None + assert typing.get_origin(f.type) is Annotated, ( + f"Field {f.name} must be annotated with an integer id." + ) + thrift_id = f.type.__metadata__[0] + assert type(thrift_id) is int, ( + f"Field {f.name} must be annotated with an integer id." 
+ ) + + value = dataclasses.MISSING + if f.default is not dataclasses.MISSING: + value = f.default + elif f.default_factory is not dataclasses.MISSING: + value = f.default_factory() + + if value is not dataclasses.MISSING: + default = str(value) + ret["default"] = default + cpp_default = dump_cpp_value(value) + + if t.startswith("Optional[") and value is not None: + raise AssertionError( + f"Optional field {ty.__name__}.{f.name} must have default value to be None." + ) + + return ret, cpp_type, cpp_default, thrift_type, thrift_id + + yaml_ret = {} + cpp_ret = {} + thrift_ret = {} + thrift_ids = set() + for f in dataclasses.fields(ty): + yaml_res, cpp_type, cpp_default, thrift_type, thrift_id = dump_field(f) + yaml_ret[f.name] = yaml_res + cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default} + thrift_ret[f.name] = {"thrift_type": thrift_type, "thrift_id": thrift_id} + if thrift_id in thrift_ids: + raise AssertionError( + f"Duplicate thrift id {thrift_id} for field {f.name} in {ty.__name__}." 
+ ) + thrift_ids.add(thrift_id) + return yaml_ret, cpp_ret, thrift_ret + + def _handle_int_enum(name, ty): + yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + cpp_enum_defs[name] = f""" +enum class {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}}; + +inline std::string_view printEnum(const {name}& e) {{ + switch (e) {{ +{chr(10).join([f" case {name}::{x.name}: return {chr(34)}{x.name}{chr(34)};" for x in ty])} + default: + throw std::runtime_error("Unknown enum value"); + }} +}} + +inline void parseEnum(std::string_view s, {name}& t) {{ +{chr(10).join([f" if (s == {chr(34)}{x.name}{chr(34)}) {{ t = {name}::{x.name}; return; }}" for x in ty])} + throw std::runtime_error("Unknown enum value: " + std::string{{s}}); +}} +""" + thrift_enum_defs.append( + f""" +enum {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}} +""" + ) + + def _handle_struct(name, ty): + fields, cpp_fields, thrift_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "struct", "fields": fields} + field_decls = "\n".join( + f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};" + for name, f in cpp_fields.items() + ) + + def accessor(name, ty): + type_name = fields[name]["type"] + if type_name in cpp_enum_defs: + return f""" + {type_name} get_{name}() const {{ + return static_cast<{type_name}>({name}); + }} + + void set_{name}({type_name} def) {{ + {name} = static_cast(def); + }} +""" + return f""" + const {ty}& get_{name}() const {{ + return {name}; + }} + + void set_{name}({ty} def) {{ + {name} = std::move(def); + }} +""" + + to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)" + to_json_def = f"""{{ +{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])} +}} +""" + from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)" + + from_json_def = f"""{{ + 
{name} nlohmann_json_default_obj; +{ + chr(10).join( + [ + f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items() + ] + ) + } +}} +""" + cpp_class_defs[name] = f""" +class {name} {{ + private: +{field_decls} + + public: +{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])} + friend {to_json_decl}; + friend {from_json_decl}; +}}; +""" + cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}") + cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") + cpp_type_decls.append(f"class {name};") + + thrift_type_defs[name] = f""" +struct {name} {{ +{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} +}}""" + + def _handle_union(name, ty): + fields, cpp_fields, thrift_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "union", "fields": fields} + + def accessor(name, ty, idx): + return f""" + const {ty}& get_{name}() const {{ + return std::get<{idx + 1}>(variant_); + }} + + void set_{name}({ty} def) {{ + variant_.emplace<{idx + 1}>(std::move(def)); + tag_ = Tag::{name.upper()}; + }} +""" + + to_json_branches = "".join( + [ + f""" + if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{ + nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}(); + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + from_json_branches = "".join( + [ + f""" + if (nlohmann_json_j.contains("{name}")) {{ + nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>()); + nlohmann_json_t.tag_ = Tag::{name.upper()}; + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + + cpp_class_defs[name] = f""" +class {name} {{ + struct Void {{}}; + + public: + enum class Tag {{ + {", ".join([name.upper() for name in cpp_fields])} + }}; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const {{ + return tag_; + }} 
+{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])} + friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{ +{to_json_branches} + }} + + friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{ +{from_json_branches} + }} +}}; + +inline std::string_view printEnum(const {name}::Tag& e) {{ + switch (e) {{ +{chr(10).join([f" case {name}::Tag::{x.upper()}: return {chr(34)}{x.upper()}{chr(34)};" for x in cpp_fields])} + default: + throw std::runtime_error("Unknown enum value"); + }} +}} + +inline void parseEnum(std::string_view s, {name}::Tag& t) {{ +{chr(10).join([f" if (s == {chr(34)}{x.upper()}{chr(34)}) {{ t = {name}::Tag::{x.upper()}; return; }}" for x in cpp_fields])} + throw std::runtime_error("Unknown enum value: " + std::string{{s}}); +}} + +""" + cpp_type_decls.append(f"class {name};") + + thrift_type_defs[name] = f""" +union {name} {{ +{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} +}}""" + + for name in dir(schema): + if name.startswith("_"): + continue + + value = getattr(schema, name) + + if hasattr(value, "__module__") and value.__module__ != schema.__name__: + continue + + defs[name] = value + + class_ordering = {} + for name, value in defs.items(): + if isinstance(value, type): + if issubclass(value, IntEnum): + _handle_int_enum(name, value) + elif dataclasses.is_dataclass(value): + class_ordering[name] = inspect.findsource(value)[1] + if issubclass(value, _Union): + _handle_union(name, value) + else: + _handle_struct(name, value) + else: + raise AssertionError(f"Unknown schema type {name}: {value}") + elif isinstance(value, (int, tuple)): + assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION") + else: + raise AssertionError(f"Unknown variable {name}: {value}") + + yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"]) + 
yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert yaml_ret["TREESPEC_VERSION"] > 0 + + cpp_header = f""" +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{ +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END }} +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> {{ + static void to_json(json& j, const std::optional& opt) {{ + if (opt == std::nullopt) {{ + j = nullptr; + }} else {{ + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + }} + }} + + static void from_json(const json& j, std::optional& opt) {{ + if (j.is_null()) {{ + opt = std::nullopt; + }} else {{ + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + }} + }} +}}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch {{ +namespace _export {{ + +template +class ForwardRef {{ + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {{}} + ForwardRef(ForwardRef&&); + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {{}} + ForwardRef& operator=(ForwardRef&&); + ForwardRef& operator=(const ForwardRef& other) {{ + ptr_ = std::make_unique(*other.ptr_); + return *this; + }} + ~ForwardRef(); + const T& operator*() const {{ + return *ptr_; + }} + + const T* operator->() const {{ + return ptr_.get(); + }} + + void emplace(T&& t) {{ + ptr_ = std::make_unique(std::move(t)); + }} + + private: + std::unique_ptr ptr_; +}}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) {{ + j = *p; +}} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) {{ + p.emplace(j.template get()); +}} + +class F64 {{ + public: + double get() const {{ + return value_; + }} 
+ + void set(double value) {{ + value_ = value; + }} + + private: + double value_; +}}; + +inline void to_json(nlohmann::json& j, const F64& f) {{ + if (std::isinf(f.get())) {{ + j = "Infinity"; + }} else if (std::isinf(-f.get())) {{ + j = "-Infinity"; + }} else if (std::isnan(f.get())) {{ + j = "NaN"; + }} else {{ + j = f.get(); + }} +}} + +inline void from_json(const nlohmann::json& j, F64& f) {{ + if (j == "Infinity") {{ + f.set(std::numeric_limits::infinity()); + }} else if (j == "-Infinity") {{ + f.set(-std::numeric_limits::infinity()); + }} else if (j == "NaN") {{ + f.set(std::numeric_limits::quiet_NaN()); + }} else {{ + f.set(j.get()); + }} +}} + +{chr(10).join(cpp_type_decls)} +{"".join(cpp_enum_defs.values())} +{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +{chr(10).join(cpp_json_defs)} + +template ForwardRef::ForwardRef(ForwardRef&&) = default; +template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +template ForwardRef::~ForwardRef() = default; +}} // namespace _export +}} // namespace torch +""" + thrift_schema = f""" +namespace py3 torch._export +namespace cpp2 torch._export.schema +{chr(10).join(thrift_enum_defs)} +{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +""" + return yaml_ret, cpp_header, thrift_schema + + +def _diff_schema(dst, src): + additions = {key: src[key] for key in src.keys() - dst.keys()} + subtractions = {key: dst[key] for key in dst.keys() - src.keys()} + + common_keys = src.keys() & dst.keys() + + versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} + common_keys -= versions + + for key in common_keys: + src_kind = src[key]["kind"] + src_fields = src[key]["fields"] + dst_kind = dst[key]["kind"] + dst_fields = dst[key]["fields"] + _check( + src_kind == dst_kind, + f"Type {key} changed kind from {dst_kind} to {src_kind}", + ) + assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) + added_fields = { + key: 
src_fields[key] for key in src_fields.keys() - dst_fields.keys() + } + subtracted_fields = { + key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() + } + common_fields = src_fields.keys() & dst_fields.keys() + + for field in common_fields: + src_field = src_fields[field] + dst_field = dst_fields[field] + if src_kind == "struct": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + if "default" in src_field and "default" not in dst_field: + added_fields[field] = {} + added_fields[field]["default"] = src_field["default"] + if "default" not in src_field and "default" in dst_field: + subtracted_fields[field] = {} + subtracted_fields[field]["default"] = dst_field["default"] + elif src_kind == "enum": + _check( + src_field == dst_field, + f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", + ) + elif src_kind == "union": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + else: + raise AssertionError(f"Unknown kind {src_kind}: {key}") + if len(added_fields) > 0: + assert key not in additions + additions[key] = {} + additions[key]["fields"] = added_fields + if len(subtracted_fields) > 0: + assert key not in subtractions + subtractions[key] = {} + subtractions[key]["fields"] = subtracted_fields + + return additions, subtractions + + +def _hash_content(s: str): + return hashlib.sha256(s.strip().encode("utf-8")).hexdigest() + + +@dataclasses.dataclass +class _Commit: + result: dict[str, Any] + checksum_next: str + yaml_path: str + additions: dict[str, Any] + subtractions: dict[str, Any] + base: dict[str, Any] + checksum_head: Optional[str] + cpp_header: str + cpp_header_path: str + thrift_checksum_head: Optional[str] + thrift_checksum_real: Optional[str] + thrift_checksum_next: str + thrift_schema: str + thrift_schema_path: str + + 
+def update_schema(): + import importlib.resources + + # pyrefly: ignore [bad-argument-type] + if importlib.resources.is_resource(__package__, "schema.yaml"): + # pyrefly: ignore [bad-argument-type] + content = importlib.resources.read_text(__package__, "schema.yaml") + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) + _check(match is not None, "checksum not found in schema.yaml") + assert match is not None + checksum_head = match.group(1) + + thrift_content = importlib.resources.read_text( + # pyrefly: ignore [bad-argument-type] + __package__, + "export_schema.thrift", + ) + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content) + _check(match is not None, "checksum not found in export_schema.thrift") + assert match is not None + thrift_checksum_head = match.group(1) + thrift_content = thrift_content.splitlines() + assert thrift_content[0].startswith("// @" + "generated") + assert thrift_content[1].startswith("// checksum<<") + thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) + + from yaml import load, Loader + + dst = load(content, Loader=Loader) + assert isinstance(dst, dict) + else: + checksum_head = None + thrift_checksum_head = None + thrift_checksum_real = None + dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} + + src, cpp_header, thrift_schema = _staged_schema() + additions, subtractions = _diff_schema(dst, src) + # pyrefly: ignore [missing-attribute] + yaml_path = __package__.replace(".", "/") + "/schema.yaml" + # pyrefly: ignore [missing-attribute] + thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift" + torch_prefix = "torch/" + assert yaml_path.startswith(torch_prefix) # sanity check + assert thrift_schema_path.startswith(torch_prefix) # sanity check + + return _Commit( + result=src, + checksum_next=_hash_content(repr(src)), + yaml_path=yaml_path, + additions=additions, + subtractions=subtractions, + base=dst, + checksum_head=checksum_head, + cpp_header=cpp_header, + 
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", + thrift_checksum_head=thrift_checksum_head, + thrift_checksum_real=thrift_checksum_real, + thrift_checksum_next=_hash_content(thrift_schema), + thrift_schema=thrift_schema, + thrift_schema_path=thrift_schema_path, + ) + + +def check(commit: _Commit, force_unsafe: bool = False): + next_version = None + reason = "" + # Step 1: Detect major schema updates. + if len(commit.additions) > 0: + for k, v in commit.additions.items(): + if k not in commit.base: + continue + kind = commit.result[k]["kind"] + fields = v["fields"] + for f, d in fields.items(): + if kind == "struct" and "default" not in d: + reason += ( + f"Field {k}.{f} is added to schema.py without a default value as an incompatible change " + + "which requires major version bump.\n" + ) + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + if k not in commit.result: + continue + for f in v["fields"]: + reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n" + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if force_unsafe: + reason += "--force-unsafe is used." + next_version = commit.result["SCHEMA_VERSION"] + else: + # Step 2: Detect minor schema updates. 
+ if next_version is None and len(commit.additions) > 0: + for k, v in commit.additions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is added to schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + if next_version is None and len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is removed from schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + + return next_version, reason diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/serialize.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..c64aaff9ae1f2b693c753a3b26fa94462cfca870 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/serialize.py @@ -0,0 +1,3936 @@ +# mypy: allow-untyped-defs +import base64 +import copy +import copyreg +import dataclasses +import heapq +import inspect +import io +import json +import keyword +import logging +import math +import operator +import re +import traceback +import typing +from collections import namedtuple, OrderedDict +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Annotated, Any, cast, final, Optional, Union + +import sympy + +import torch +import torch.export.exported_program as ep +from torch._export.non_strict_utils import _enable_graph_inputs_of_type_nn_module +from torch._export.verifier import load_verifier +from torch._subclasses.fake_tensor import FakeTensor, 
FakeTensorMode +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.experimental import symbolic_shapes +from torch.utils import _pytree as pytree +from torch.utils._pytree import treespec_dumps, treespec_loads +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.symbol import prefix_str, SymT +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._traceback import CapturedTraceback +from torch.utils._triton import has_triton + +from ..utils import remove_proxy_from_state_dict +from . import schema +from .schema import ( # type: ignore[attr-defined] + Argument, + ArgumentKind, + BufferMutationSpec, + ComplexValue, + ConstantValue, + CustomObjArgument, + Device, + ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, + Graph, + GraphArgument, + GraphModule, + GraphSignature, + InputSpec, + InputToBufferSpec, + InputToConstantInputSpec, + InputToCustomObjSpec, + InputTokenSpec, + InputToParameterSpec, + InputToTensorConstantSpec, + Layout, + LossOutputSpec, + MemoryFormat, + ModuleCallEntry, + ModuleCallSignature, + NamedArgument, + NamedTupleDef, + Node, + OptionalTensorArgument, + OutputSpec, + OutputTokenSpec, + ParameterMutationSpec, + RangeConstraint, + ScalarType, + SCHEMA_VERSION, + SchemaVersion, + SymBool, + SymBoolArgument, + SymExpr, + SymExprHint, + SymFloat, + SymFloatArgument, + SymInt, + SymIntArgument, + TensorArgument, + TensorMeta, + TokenArgument, + TREESPEC_VERSION, + UserInputMutationSpec, + UserInputSpec, + UserOutputSpec, +) +from .union import _Union + + +__all__ = [ + "serialize", + "GraphModuleSerializer", + "ExportedProgramSerializer", + "GraphModuleDeserializer", + "ExportedProgramDeserializer", +] + +log = logging.getLogger(__name__) + + +class SerializeError(RuntimeError): + pass + + +def _reverse_map(d: dict[Any, Enum]): + return {v.value: k for k, v in d.items()} + + +MetaType = Union[ + FakeTensor, + int, + torch.SymInt, + float, + torch.SymFloat, + bool, + 
torch.SymBool, + ep.CustomObjArgument, +] + +DEFAULT_PICKLE_PROTOCOL = 2 + +ST_DELIMITER = ";" + +_TORCH_TO_SERIALIZE_DTYPE = { + torch.uint8: ScalarType.BYTE, + torch.int8: ScalarType.CHAR, + torch.uint16: ScalarType.UINT16, + torch.int16: ScalarType.SHORT, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.float64: ScalarType.DOUBLE, + torch.complex32: ScalarType.COMPLEXHALF, + torch.complex64: ScalarType.COMPLEXFLOAT, + torch.complex128: ScalarType.COMPLEXDOUBLE, + torch.bool: ScalarType.BOOL, + torch.bfloat16: ScalarType.BFLOAT16, + torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, + torch.float8_e5m2: ScalarType.FLOAT8E5M2, + torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ, + torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ, +} + + +_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_LAYOUT = { + torch.sparse_coo: Layout.SparseCoo, + torch.sparse_csr: Layout.SparseCsr, + torch.sparse_csc: Layout.SparseCsc, + torch.sparse_bsr: Layout.SparseBsr, + torch.sparse_bsc: Layout.SparseBsc, + torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] + torch.strided: Layout.Strided, +} + + +_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { + torch.contiguous_format: MemoryFormat.ContiguousFormat, + torch.channels_last: MemoryFormat.ChannelsLast, + torch.channels_last_3d: MemoryFormat.ChannelsLast3d, + torch.preserve_format: MemoryFormat.PreserveFormat, +} + + +_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] + +_SYM_OPS = { + operator.eq, + operator.ne, + operator.le, + operator.ge, + operator.lt, + operator.gt, + operator.neg, + operator.pos, + operator.and_, + operator.or_, + math.trunc, + torch.sym_not, + operator.mul, + operator.add, + operator.sub, + operator.floordiv, 
+ operator.mod, + operator.pow, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_sqrt, + operator.truediv, + operator.and_, +} + + +assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_OPS) + + +@dataclass +class SerializedArtifact: + exported_program: bytes + state_dict: bytes + constants: bytes + example_inputs: bytes + + +@dataclass +class _SerializedProgram: + exported_program: ExportedProgram + state_dict: bytes + constants: bytes + example_inputs: bytes + + +class LazyMap(dict): + """ + Dictionary class for deferred instantiation of node metadata values. + Purpose is to avoid creation of symbolic-shape tensors before relevant shape guards are parsed. + """ + + def __init__(self): + self.map = {} + self.evaluated = set() + + def __setitem__(self, k, v): + self.map[k] = v + + def __getitem__(self, k): + out = self.map[k] + if k in self.evaluated: + return out + self.evaluated.add(k) + self.map[k] = out() + return self.map[k] + + def __repr__(self): + return self.map.__repr__() + + +def deserialize_device(d: Device) -> torch.device: + if d.index is None: + return torch.device(type=d.type) # type: ignore[call-overload] + return torch.device(type=d.type, index=d.index) + + +def deserialize_size(sizes: Sequence[SymInt]) -> tuple[int, ...]: + for sym_int_size in sizes: + assert sym_int_size.type == "as_int", ( + f"Only as_int is supported, got {sym_int_size.type}" + ) + return tuple(sym_int_size.as_int for sym_int_size in sizes) + + +def deserialize_stride(strides: Sequence[SymInt]) -> tuple[int, ...]: + for sym_int_stride in strides: + assert sym_int_stride.type == "as_int", ( + f"Only as_int is supported, got {sym_int_stride.type}" + ) + return tuple(sym_int_stride.as_int for sym_int_stride in strides) + + +def deserialize_scalar_type(st: ScalarType) -> torch.dtype: + return _SERIALIZE_TO_TORCH_DTYPE[st] + + +def deserialize_storage_offset(offset: SymInt) -> int: + assert offset.type == "as_int", 
f"Only as_int is supported, got {offset.type}" + return offset.as_int + + +def _print_sympy(s: Union[torch.SymInt, torch.SymBool, torch.SymFloat, sympy.Expr]): + if isinstance(s, (torch.SymInt, torch.SymBool, torch.SymFloat)): + s = s.node.expr + return sympy.printing.repr.srepr(s) + + +def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: + if isinstance(s, (torch.SymInt, sympy.Symbol, int)): + if symbolic_shapes.is_concrete_int(s): + return SymInt.create(as_int=int(s)) + else: + assert isinstance(s, (torch.SymInt, sympy.Symbol)) + if s.node.hint is None: + return SymInt.create(as_expr=SymExpr(_print_sympy(s))) + else: + return SymInt.create( + as_expr=SymExpr( + _print_sympy(s), + hint=SymExprHint.create(as_int=s.node.hint), + ) + ) + else: + raise SerializeError( + f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_float(s: Union[float, torch.SymFloat]) -> SymFloat: + if isinstance(s, (torch.SymFloat, sympy.Symbol, float)): + if symbolic_shapes.is_concrete_float(s): + return SymFloat.create(as_float=float(s)) + else: + assert isinstance(s, (torch.SymFloat, sympy.Symbol)) + if s.node.hint is None: + return SymFloat.create(as_expr=SymExpr(_print_sympy(s))) + else: + return SymFloat.create( + as_expr=SymExpr( + _print_sympy(s), + hint=SymExprHint.create(as_float=s.node.hint), + ) + ) + else: + raise SerializeError( + f"SymFloat should be either symbol or float, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: + if isinstance(s, (torch.SymBool, bool)): + if symbolic_shapes.is_concrete_bool(s): + return SymBool.create(as_bool=bool(s)) + else: + return SymBool.create(as_expr=SymExpr(expr_str=_print_sympy(s))) + else: + raise SerializeError( + f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" + ) + + +def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: + """ + Extract a TensorMeta describing `t`. 
+ """ + return TensorMeta( + dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], + sizes=[serialize_sym_int(s) for s in t.shape], + requires_grad=t.requires_grad, + device=Device(type=t.device.type, index=t.device.index), + strides=[serialize_sym_int(s) for s in t.stride()], + storage_offset=serialize_sym_int(t.storage_offset()), + layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], + ) + + +_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None + + +def _reduce_fake_tensor(fake_tensor: FakeTensor): + is_parameter = isinstance(fake_tensor, torch.nn.Parameter) + tensor_meta = serialize_tensor_meta(fake_tensor) + tensor_meta_bytes = json.dumps( + _dataclass_to_dict(tensor_meta), cls=EnumEncoder + ).encode("utf-8") + return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) + + +def _reconstruct_fake_tensor( + serialized_tensor_meta: bytes, is_parameter: bool +) -> FakeTensor: + # Deserialize the bytes into a TensorMeta + json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) + tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) + # Find the current fake mode + assert _CURRENT_DESERIALIZER is not None, ( + "Need access to current deserializer state" + ) + fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) + if is_parameter: + fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] + # pyrefly: ignore [bad-return] + return fake_tensor + + +def serialize_torch_artifact( + artifact: Optional[Any], pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL +) -> bytes: + if artifact is None: + return b"" + + assert FakeTensor not in copyreg.dispatch_table, ( + "Refusing to stomp on existing FakeTensor reducer" + ) + try: + copyreg.pickle(FakeTensor, _reduce_fake_tensor) + buffer = io.BytesIO() + # This is a workaround for backend's tensor deserialization problem: + # unpickleTensor() always create a tensor on the device where it was originally saved + # This behavior is bad for multi-gpu training, as we wish to 
directly load the tensor + # on the designated device. + # For now, we simply move the tensor to cpu before saving. + # TODO: this should be fixed by deserialization instead. + torch.save(artifact, buffer, pickle_protocol=pickle_protocol) + return buffer.getvalue() + finally: + del copyreg.dispatch_table[FakeTensor] + + +def deserialize_torch_artifact( + serialized: Union[dict[str, Any], tuple[Any, ...], bytes], +): + if isinstance(serialized, (dict, tuple)): + return serialized + if len(serialized) == 0: + return {} + buffer = io.BytesIO(serialized) + buffer.seek(0) + # weights_only=False as we want to load custom objects here (e.g. ScriptObject) + try: + artifact = torch.load(buffer, weights_only=True) + except Exception as e: + buffer.seek(0) + artifact = torch.load(buffer, weights_only=False) + log.warning( + "Fallback to weights_only=False succeeded. " + "Loaded object of type %s after initial failure: %s", + type(artifact), + exc_info=e, + ) + assert isinstance(artifact, (tuple, dict)) + return artifact + + +def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]: + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return None + if val in (-sympy.oo, -int_oo): + return None + if isinstance(val, sympy.Integer): + return int(val) + + # TODO: Remove this adjustment when Ed gets rid of fractional ranges + log.warning( + "Export constraints cannot be non-integer expressions. Found " + "type %s, and value %s. 
We will attempt to %s " + "this value.", + type(val), + val, + adjust, + ) + + if adjust == "floor": + return math.floor(val) + elif adjust == "ceil": + return math.ceil(val) + else: + raise RuntimeError(f"Got invalid adjustment {adjust}") + + +def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr: + # Convert concrete int into simple sympy Integers + if val is None: + return default + if val in [-int_oo, int_oo]: + return val + if val == math.inf: + return int_oo + if val == -math.inf: + return -int_oo + return sympy.Integer(val) + + +def _symbol_index(sym: sympy.Symbol, sym_type: SymT): + return int(str(sym)[len(prefix_str[sym_type]) :]) + + +def serialize_range_constraints( + range_constraints: dict[sympy.Symbol, ValueRanges], +) -> dict[str, RangeConstraint]: + return { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] + _sympy_int_to_int(v.upper, "floor"), # type: ignore[arg-type] + ) + for k, v in range_constraints.items() + } + + +def _get_schema_from_target(target): + if isinstance(target, torch._ops.OpOverload): + return target._schema + elif type(target) in _serialization_registry: + return _serialization_registry[type(target)].op_schema(target) + raise RuntimeError(f"Cannot find schema for {type(target)}") + + +@dataclass +class GraphState: + inputs: list[Argument] = field(default_factory=list) + outputs: list[Argument] = field(default_factory=list) + nodes: list[Node] = field(default_factory=list) + tensor_values: dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: dict[str, SymBool] = field(default_factory=dict) + sym_float_values: dict[str, SymFloat] = field(default_factory=dict) + is_single_tensor_return: bool = False + custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict) + + +class Final(type): + def __new__(metacls, name, bases, classdict): + for b in bases: + if isinstance(b, Final): + 
raise TypeError(f"type '{b.__name__}' is not an acceptable base type") + return type.__new__(metacls, name, bases, dict(classdict)) + + +def is_metadata_matched(config, entry_metadata): + metadata_attrs = ["num_cpu_threads", "num_warps", "num_stages", "num_ctas"] + for attr in metadata_attrs: + if hasattr(config, attr) and hasattr(entry_metadata, attr): + if getattr(config, attr) != getattr(entry_metadata, attr): + return False + return True + + +def get_triton_kernel_and_cache_entry(node: torch.fx.Node): + assert ( + node.target + is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional + ) + + assert has_triton(), "triton required to serialize triton kernels" + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + assert isinstance(node.kwargs["kernel_idx"], int) + kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel( + node.kwargs["kernel_idx"] + ) + + # For Autotuner, we need to look at the underlying JITFunction's cache + # since the Autotuner itself doesn't have a cache + is_autotuner = isinstance(kernel, Autotuner) + # pyrefly: ignore [missing-attribute] + actual_kernel = kernel.fn if is_autotuner else kernel + + if hasattr(actual_kernel, "device_caches"): + caches = actual_kernel.device_caches + assert len(caches.keys()) == 1 + cache = next(iter(caches.values()))[0] + elif hasattr(actual_kernel, "cache"): + # old path, still used for cpu triton builds + caches = actual_kernel.cache + assert len(caches.keys()) == 1 + cache = next(iter(caches.values())) + else: + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"kernel caches not found for kernel {actual_kernel.__name__}" + ) + + if len(cache.keys()) == 1: + return actual_kernel, next(iter(cache.values())) + + has_constexprs = ( + isinstance(actual_kernel, JITFunction) + and hasattr(actual_kernel, "constexprs") + and len(actual_kernel.constexprs) > 0 + ) + + if has_constexprs: + constexpr_vals = {} + # 
pyrefly: ignore [missing-attribute] + for constexpr_idx in actual_kernel.constexprs: + # pyrefly: ignore [missing-attribute] + if constexpr_idx < len(actual_kernel.arg_names): + # pyrefly: ignore [missing-attribute] + param_name = actual_kernel.arg_names[constexpr_idx] + kwargs_dict = node.kwargs.get("kwargs", {}) + if isinstance(kwargs_dict, dict): + if param_name in kwargs_dict: + constexpr_vals[param_name] = kwargs_dict[param_name] + + expected_values = [ + # pyrefly: ignore [missing-attribute] + constexpr_vals[actual_kernel.arg_names[idx]] + # pyrefly: ignore [missing-attribute] + for idx in actual_kernel.constexprs + # pyrefly: ignore [missing-attribute] + if actual_kernel.arg_names[idx] in constexpr_vals + ] + + matching_entries = [] + for sig_key, cache_entry in cache.items(): + constexpr_matches = re.findall(r"\('constexpr',\s*([^)]+)\)", sig_key) + if constexpr_matches: + constexpr_values = [] + for match in constexpr_matches: + if match in ("True", "False"): + constexpr_values.append(match == "True") + elif "." in match or "e" in match or "E" in match: + constexpr_values.append(float(match)) + else: + constexpr_values.append(int(match)) + + if constexpr_values == expected_values: + matching_entries.append((sig_key, cache_entry)) + else: + matching_entries = list(cache.items()) + + if len(matching_entries) == 0: + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {actual_kernel.__name__}. 
" + f"Available cache keys: {list(cache.keys())}" + ) + + if len(matching_entries) == 1: + return actual_kernel, matching_entries[0][1] + + if is_autotuner: + for _sig_key, cache_entry in matching_entries: + entry_metadata = cache_entry.metadata + # pyrefly: ignore [missing-attribute] + for config in kernel.configs: + if is_metadata_matched(config, entry_metadata): + return actual_kernel, cache_entry + + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"Multiple cache entries found for autotuned kernel {actual_kernel.__name__} " + f"{'with same constexpr values' if has_constexprs else 'with no constexpr'} " + f"and couldn't disambiguate using configs. " + ) + + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"Multiple cache entries found for non-autotuned kernel {actual_kernel.__name__} " + f"{'with same constexpr values' if has_constexprs else 'with no constexpr'}. " + f"This should not happen. Available cache keys: {[key for key, _ in matching_entries]}" + ) + + +@final +class GraphModuleSerializer(metaclass=Final): + def __init__( + self, + graph_signature: ep.ExportGraphSignature, + module_call_graph: list[ep.ModuleCallEntry], + ): + self.graph_state = GraphState() + self.graph_signature = graph_signature + self.module_call_graph = module_call_graph + self.custom_objs: dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: dict[str, str] = {} + self.treespec_namedtuple_fields: dict[str, NamedTupleDef] = {} + + @contextmanager + def save_graph_state(self): + saved = self.graph_state + self.graph_state = GraphState() + try: + yield + finally: + self.graph_state = saved + + def handle_placeholder(self, node: torch.fx.Node): + assert node.op == "placeholder" + val = node.meta["val"] + log.debug("[handle_placeholder] %s: %s", node.name, val) + if isinstance(val, torch.Tensor): + graph_input = Argument.create( + as_tensor=self.serialize_tensor_output(node.name, val) + ) + elif isinstance(val, torch.SymInt): + 
graph_input = Argument.create( + as_sym_int=self.serialize_sym_int_output(node.name, val) + ) + elif isinstance(val, torch.SymFloat): + raise AssertionError("SymFloat graph input is not implemented yet.") + elif isinstance(val, (int, bool, str, float, type(None))): + graph_input = self.serialize_input(val) + elif isinstance(val, ep.CustomObjArgument): + class_fqn = val.class_fqn + graph_input = Argument.create( + as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) + ) + self.graph_state.custom_obj_values[node.name] = ( + self.serialize_script_obj_meta(val) + ) + else: + raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") + self.graph_state.inputs.append(graph_input) + + def handle_output(self, node: torch.fx.Node): + assert node.op == "output" + assert len(node.args) == 1, "FX.Node's args should have one arg" + node_args = node.args[0] + log.debug("[handle_output] %s: %s", node.name, node_args) + if isinstance(node_args, torch.fx.Node): + # For singleton tensor returns + self.graph_state.is_single_tensor_return = True + self.graph_state.outputs = [self.serialize_input(node_args)] + else: + assert isinstance(node_args, (tuple, list)) + self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] + + def serialize_operator(self, target) -> str: + if isinstance(target, str): + return target + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. 
+ return f"{target.__module__}.{target.__name__}" + + def handle_call_function(self, node: torch.fx.Node): + assert node.op == "call_function" + meta_val = node.meta.get("val") + log.debug( + "[handle_call_function] %s: %s(%s, {%s}) -> %s", + node.name, + node.target, + node.args, + node.kwargs, + meta_val, + ) + + # getitem has been handled in the producer node, skip it here + if node.target is operator.getitem: + return + + if node.target in _SYM_OPS or ( + meta_val is not None + and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)) + ): + assert len(node.kwargs) == 0 + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[self.serialize_output(node.name, meta_val)], + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.OpOverload): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + # TODO: create a new tensor_values here, meta might have faketensor info + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.HigherOrderOperator): + + def _is_hop_single_tensor_return(node) -> bool: + assert isinstance(node.target, torch._ops.HigherOrderOperator) + # HOP schema is not always available, so we look at node.meta["val"] + meta_val = node.meta.get("val", None) + return meta_val is not None and isinstance(meta_val, torch.Tensor) + + # Special handle serialization for aoti_call_delegate + if node.target is torch._higher_order_ops.aoti_call_delegate: + serializable_args = list(node.args) + + # AOTI lowered module is not serializable, serialize the aoti_path instead + lowered_module_name: str = node.args[0].name # type: ignore[assignment, no-untyped-def, union-attr] + assert hasattr(node.graph.owning_module, lowered_module_name) + lowered_module = getattr(node.graph.owning_module, 
lowered_module_name) # type: ignore[no-untyped-def] + serializable_args[0] = lowered_module.aoti_path + + # AOTI compiled graph module in node.args[0] is stateful, and will fail the verifier check + # Skip serializing original_gm as a workaround + serializable_args[1] = None + + serializable_weight_nodes = [] + if serializable_args[2] is not None and isinstance( + serializable_args[2], Iterable + ): + for weight_node in serializable_args[2]: + # skip passing custom obj into the weight arg as an hack + # The schema of weight input is a list of Tensors. + # Downstream runtime is not actively consuming the weighs arg for anything meaningful. + if isinstance(weight_node, torch.fx.Node) and isinstance( + weight_node.meta.get("val", None), ep.CustomObjArgument + ): + continue + serializable_weight_nodes.append(weight_node) + serializable_args[2] = serializable_weight_nodes + + def serialize_tensor_list_output(node): + meta_val = node.meta.get("val", None) + tensor_args = [] + for idx, meta in enumerate(meta_val): + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(serializable_args, node.kwargs), + outputs=serialize_tensor_list_output(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=False, + ) + elif ( + node.target + is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional + ): + kernel, kernel_cache_entry = get_triton_kernel_and_cache_entry(node) + kernel_cache_metadata = kernel_cache_entry.metadata + + meta_val = node.meta["val"] + assert isinstance(meta_val, dict) + + output_keys = meta_val.keys() + output_indices = [] + + constexpr_keys = {p.name for p in kernel.params if p.is_constexpr} + found_constexpr = False + args_new = () + i = 0 + + assert isinstance(node.kwargs["kwargs"], dict) + for k, v in 
node.kwargs["kwargs"].items(): + # don't serialize constexpr since they will + # be embedded into the binary and don't + # need to be passed around as attributes + if k in constexpr_keys: + found_constexpr = True + continue + + assert not found_constexpr, ( + "non-constexpr args found after constexpr arg(s)" + ) + + if k in output_keys: + output_indices.append(i) + args_new += (v,) # type: ignore[assignment] + i += 1 + + assert isinstance(node.kwargs["grid"], list) + + kernel_name_with_hash = ( + f"{kernel.fn.__name__}_{kernel_cache_metadata.hash}" + ) + kwargs_new = { + "name": kernel_name_with_hash, + "grid": node.kwargs["grid"][0], + "output_indices": output_indices, + "num_warps": kernel_cache_metadata.num_warps, + } + if hasattr(kernel_cache_metadata, "num_cpu_threads"): + kwargs_new["num_cpu_threads"] = ( + kernel_cache_metadata.num_cpu_threads + ) + + if hasattr(kernel_cache_metadata, "shared"): + kwargs_new["shared_memory_bytes"] = kernel_cache_metadata.shared + + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(args_new, kwargs_new), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=_is_hop_single_tensor_return(node), + ) + else: + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(node.args, node.kwargs), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=_is_hop_single_tensor_return(node), + ) + elif type(node.target) in _serialization_registry: + # Sanity check for unhandled serialization. + assert type(node.target) in _serialization_registry, ( + f"{type(node.target)} is not supported in export serialization." 
+ ) + + handler = _serialization_registry[type(node.target)] + namespace = handler.namespace() + op_name = handler.to_op_name(node.target) + assert isinstance(namespace, str) and isinstance(op_name, str) + assert ":" not in namespace and ":" not in op_name + ex_node = Node( + target=f"#{namespace}:{op_name}", + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + metadata=self.serialize_metadata(node), + ) + else: + raise SerializeError(f"Serializing {node.target} is not supported") + + self.graph_state.nodes.append(ex_node) + + def handle_get_attr(self, node): + log.debug("[handle_get_attr] %s", node.name) + + def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]: + user_node = None + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + if user_node is None: + user_node = user + else: + # We want to deduplicate getitem nodes that are trying to + # index to the same index + self.duplicate_getitem_nodes[user.name] = user_node.name + return user_node + + def _output_node_name_at_index(self, node, index) -> str: + user_node = self._output_node_at_index(node, index) + if user_node is None: + return f"{node.name}_unused_{index}" + else: + return user_node.name + + def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]: + ret = {} + + if stack_trace := node.meta.get("stack_trace"): + ret["stack_trace"] = stack_trace + + if nn_module_stack := node.meta.get("nn_module_stack"): + + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + + assert isinstance(path, str) + assert isinstance(ty, str) + + return path + "," + ty + + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) + + if source_fn_st := 
node.meta.get("source_fn_stack"): + source_fn_list = [ + f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" + for source_fn in source_fn_st + ] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) + + if torch_fn := node.meta.get("torch_fn"): + ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn)) + + if custom := node.meta.get("custom"): + try: + ret["custom"] = json.dumps(custom) + except Exception as e: + raise SerializeError( + f"Failed to serialize custom metadata for node {node.name} with error {e}" + ) from e + + return ret + + def serialize_script_obj_meta( + self, script_obj_meta: ep.CustomObjArgument + ) -> CustomObjArgument: + log.debug("[serialize_script_obj_meta] %s", script_obj_meta) + return CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]: + if isinstance(op, torch._ops.OpOverload): + args_names = [arg.name for arg in op._schema.arguments] + else: + assert op in _SYM_OPS + args_names = list(inspect.signature(op).parameters.keys()) + serialized_args = [] + for args_name, arg in zip(args_names, args): + serialized_args.append( + NamedArgument( + name=args_name, + arg=self.serialize_input(arg), + kind=ArgumentKind.POSITIONAL, + ) + ) + return serialized_args + + def serialize_inputs( + self, + target: Any, # torch._ops.OpOverload and other custom operator types. 
+ args, + kwargs=None, + ) -> list[NamedArgument]: + schema = None + serialized_args = [] + + if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): + obj = args[0] + method = args[1] + schema = target.schema(obj, method) + else: + assert isinstance( + target, (torch._ops.OpOverload, *_registered_extension_types()) + ) + schema = _get_schema_from_target(target) + assert schema is not None + kwargs = kwargs or {} + + for i, schema_arg in enumerate(schema.arguments): + if schema_arg.name in kwargs: + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input( + kwargs[schema_arg.name], schema_arg.type + ), + kind=ArgumentKind.KEYWORD, + ) + ) + elif not schema_arg.kwarg_only and i < len(args): + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(args[i], schema_arg.type), + kind=ArgumentKind.POSITIONAL, + ) + ) + else: + # We intentionally don't serialize the missing arguments + # with default values + pass + + return serialized_args + + def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]: + """ + For serializing HOO inputs since HOOs do not have a schema. + """ + inputs = [ + NamedArgument( + name="", arg=self.serialize_input(a), kind=ArgumentKind.POSITIONAL + ) + for a in args + ] + inputs.extend( + [ + NamedArgument( + name=name, + arg=self.serialize_input(a), + kind=ArgumentKind.KEYWORD, + ) + for name, a in kwargs.items() + ] + ) + return inputs + + def is_inductor_sym_int_arg(self, arg) -> bool: + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node and should be + # verified with is_sym_int_arg() + return type(arg) is int or isinstance(arg, torch.SymInt) + + def is_sym_int_arg(self, arg) -> bool: + return type(arg) is int or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_int_values + ) + + def is_sym_float_arg(self, arg) -> bool: + return isinstance(arg, float) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_float_values + ) + + def is_sym_bool_arg(self, arg) -> bool: + return isinstance(arg, bool) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_bool_values + ) + + # should be torch._C.JitType but that annotation is busted + def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: + import torch._inductor.ir as inductor_ir + + inductor_tensor_buffers = ( + inductor_ir.Buffer, + inductor_ir.ReinterpretView, + ) + + if isinstance(arg, torch.fx.Node): + if arg.op == "get_attr": + assert isinstance(arg.target, str) + attr = getattr(arg.graph.owning_module, arg.target) + + if isinstance(attr, torch.Tensor): + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + elif isinstance(attr, torch.fx.GraphModule): + with self.save_graph_state(): + graph = self.serialize_graph(attr) + return Argument.create( + as_graph=GraphArgument(name=arg.target, graph=graph) + ) + elif type(attr).__name__ == "LoweredBackendModule": + # Special handling for executorch_call_delegate HOP + # It's first argument is a LoweredBackendModule, for which we + # serialize name and backend id of the lowered module + module_name = getattr(attr, "module_name", None) + backend_id = getattr(attr, "backend_id", None) + assert module_name is not None, "module_name should not be None" + assert backend_id is not None, "backend_id should not be None" + return Argument.create(as_string=f"{module_name}-{backend_id}") + else: + raise SerializeError( + 
f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" + ) + elif self.is_sym_int_arg(arg): + return Argument.create( + as_sym_int=SymIntArgument.create(as_name=arg.name) + ) + elif self.is_sym_float_arg(arg): + return Argument.create( + as_sym_float=SymFloatArgument.create(as_name=arg.name) + ) + elif self.is_sym_bool_arg(arg): + return Argument.create( + as_sym_bool=SymBoolArgument.create(as_name=arg.name) + ) + elif isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn + ) + ) + elif arg.name in self.duplicate_getitem_nodes: + dedup_name = self.duplicate_getitem_nodes[arg.name] + return Argument.create(as_tensor=TensorArgument(name=dedup_name)) + else: + return Argument.create(as_tensor=TensorArgument(name=arg.name)) + elif isinstance(arg, inductor_tensor_buffers): + # Other branches are for arguments in fx node. + # This is a special branch for handling buffers (representing tensor arguments) + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + return Argument.create(as_tensor=TensorArgument(name=arg_name)) + elif isinstance(arg, inductor_ir.TorchBindObject): + # This is a special branch for handling TorchBindObject + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + arg_val = arg.get_real_obj() + class_fqn = arg_val._type().qualified_name() + self.custom_objs[arg_name] = arg_val + return Argument.create(as_custom_obj=CustomObjArgument(arg_name, class_fqn)) + elif isinstance(arg, torch.SymInt): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) + elif isinstance(arg, torch.SymFloat): + # This is a special branch for handling SymFloat args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_float_arg(arg) being true + return Argument.create( + as_sym_float=SymFloatArgument.create(as_name=str(arg)) + ) + elif type(arg) is bool: + return Argument.create(as_bool=arg) + elif type(arg) is str: + return Argument.create(as_string=arg) + elif type(arg) is int: + return Argument.create(as_int=arg) + elif type(arg) is float: + return Argument.create(as_float=arg) + elif type(arg) is complex: + return Argument.create( + as_complex=ComplexValue(real=arg.real, imag=arg.imag) + ) + elif arg is None: + return Argument.create(as_none=True) + elif isinstance(arg, dict): + serialized_dict = {} + for key, value in arg.items(): + if not isinstance(key, str): + raise SerializeError(f"Dict keys must be strings, got {type(key)}") + serialized_dict[key] = self.serialize_input(value) + return Argument.create(as_string_to_argument=serialized_dict) + elif isinstance(arg, (list, tuple)): + if len(arg) == 0: + if arg_type is not None: + if isinstance(arg_type, torch.OptionalType): + arg_type = arg_type.getElementType() # type: ignore[assignment] + assert isinstance(arg_type, torch.ListType) + elem_type = arg_type.getElementType() + if isinstance(elem_type, torch.OptionalType): + elem_type = elem_type.getElementType() + + if isinstance(elem_type, torch.BoolType): + return Argument.create(as_bools=[]) + elif isinstance(elem_type, torch.IntType): + return Argument.create(as_ints=[]) + elif isinstance(elem_type, torch.FloatType): + return Argument.create(as_floats=[]) + elif isinstance(elem_type, torch.StringType): + return Argument.create(as_strings=[]) + elif isinstance(elem_type, 
torch.TensorType): + return Argument.create(as_tensors=[]) + else: + # I believe empty symint lists default to ints, but + # please file an issue if this is not the case + raise SerializeError(f"Empty list with type {elem_type} nyi.") + else: + # We could serialize this by default to a tensor list. This + # is needed in the HOO case + log.warning( + "Unsure how to serialize the given empty list, " + "as we don't know what is the type of this argument. " + "Serializing it as a tensor list by default." + ) + return Argument.create(as_tensors=[]) + + if all(type(a) is bool for a in arg): + return Argument.create(as_bools=list(arg)) + elif all(type(a) is int for a in arg): + return Argument.create(as_ints=list(arg)) + elif all(type(a) is float for a in arg): + return Argument.create(as_floats=list(arg)) + elif all(type(a) is str for a in arg): + return Argument.create(as_strings=list(arg)) + elif all(self.is_inductor_sym_int_arg(a) for a in arg): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node + values = [] + for a in arg: + if isinstance(a, torch.SymInt): + values.append(SymIntArgument.create(as_name=str(a))) + elif type(a) is int: + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(isinstance(a, torch.SymFloat) for a in arg): + return Argument.create( + as_sym_floats=[SymFloatArgument.create(as_name=str(a)) for a in arg] + ) + elif all(self.is_sym_int_arg(a) for a in arg): + # list of sym_ints + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymIntArgument.create(as_name=a.name)) + elif type(a) is int: + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(self.is_sym_float_arg(a) for a in arg): + # list of sym_float + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymFloatArgument.create(as_name=a.name)) + elif isinstance(a, float): + values.append(SymFloatArgument.create(as_float=a)) + return Argument.create(as_sym_floats=values) + elif all(self.is_sym_bool_arg(a) for a in arg): + # list of sym_bools + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymBoolArgument.create(as_name=a.name)) + elif isinstance(a, bool): + values.append(SymBoolArgument.create(as_bool=a)) + return Argument.create(as_sym_bools=values) + elif all(isinstance(a, torch.fx.Node) for a in arg): + # list of tensors + arguments = [] + for a in arg: + if a.op == "get_attr": + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + arguments.append(TensorArgument(name=a.name)) + return Argument.create(as_tensors=arguments) + elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): + # list of optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=True) + elif isinstance(a, torch.fx.Node): + return 
OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.name) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all(isinstance(a, inductor_tensor_buffers) for a in arg): + # list of inductor buffers + return Argument.create( + as_tensors=[TensorArgument(name=a.get_name()) for a in arg], + ) + elif all( + isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg + ): + # list of inductor buffers as optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=True) + elif isinstance(a, inductor_tensor_buffers): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.get_name()) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all( + isinstance(a, tuple) and all(type(x) is int for x in a) for a in arg + ): + # list of int tuples + return Argument.create(as_int_lists=[list(t) for t in arg]) + else: + raise SerializeError( + f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" + ) + elif isinstance(arg, torch.dtype): + return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) + elif isinstance(arg, torch.device): + return Argument.create(as_device=Device(type=arg.type, index=arg.index)) + elif isinstance(arg, torch.memory_format): + return Argument.create( + as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg] + ) + elif isinstance(arg, torch.layout): + return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) + elif isinstance(arg, torch._C.ScriptObject): + if not ( + arg._has_method("__getstate__") # type: ignore[attr-defined] + and arg._has_method("__setstate__") # type: ignore[attr-defined] + ): + raise SerializeError( + f"Unable to serialize custom class {arg}. 
Please define " + "serialization methods via def_pickle()." + ) + # Custom objects through torchind are serializable with pickle, + # through implementing the .def_pickle function. This should result + # in the object containing a __getstate__ and __setstate__ + # serialize/deserialize function. + custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create( + as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) + ) + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return Argument.create(as_operator=self.serialize_operator(arg)) + else: + raise SerializeError( + f"Unsupported argument type: {type(arg)} with schema arg_type {arg_type}" + ) + + def serialize_tensor_output(self, name, meta_val) -> TensorArgument: + assert name not in self.graph_state.tensor_values + self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) + return TensorArgument(name=name) + + def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_int_values + self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) + return SymIntArgument.create(as_name=name) + + def serialize_sym_float_output(self, name, meta_val) -> SymFloatArgument: + assert name not in self.graph_state.sym_float_values + self.graph_state.sym_float_values[name] = serialize_sym_float(meta_val) + return SymFloatArgument.create(as_name=name) + + def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_bool_values + self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) + return SymBoolArgument.create(as_name=name) + + def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + log.debug("[serialize_input_spec] %s", spec) + if spec.kind == ep.InputKind.USER_INPUT: + if isinstance(spec.arg, 
ep.ConstantArgument): + if type(spec.arg.value) is int: + constant_spec = ConstantValue.create(as_int=spec.arg.value) + elif type(spec.arg.value) is bool: + constant_spec = ConstantValue.create(as_bool=spec.arg.value) + elif type(spec.arg.value) is str: + constant_spec = ConstantValue.create(as_string=spec.arg.value) + elif type(spec.arg.value) is float: + constant_spec = ConstantValue.create(as_float=spec.arg.value) + elif spec.arg.value is None: + constant_spec = ConstantValue.create(as_none=True) + else: + raise SerializeError( + f"Unhandled constant input {spec.arg.value} to serialize" + ) + return InputSpec.create( + constant_input=InputToConstantInputSpec( + name=spec.arg.name, value=constant_spec + ) + ) + else: + return InputSpec.create( + user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + assert spec.persistent is not None + return InputSpec.create( + buffer=InputToBufferSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + persistent=spec.persistent, + ) + ) + elif spec.kind == ep.InputKind.CONSTANT_TENSOR: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + tensor_constant=InputToTensorConstantSpec( + arg=TensorArgument(name=spec.arg.name), + tensor_constant_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.CUSTOM_OBJ: + assert spec.target is not None + assert isinstance(spec.arg, ep.CustomObjArgument) + return InputSpec.create( + custom_obj=InputToCustomObjSpec( + arg=CustomObjArgument( + name=spec.arg.name, class_fqn=spec.arg.class_fqn + ), + 
custom_obj_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return InputSpec.create( + token=InputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + log.debug("[serialize_output_spec] %s", spec) + if spec.kind == ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name)) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.PARAMETER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + parameter_mutation=ParameterMutationSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: + assert spec.target is not 
None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + user_input_mutation=UserInputMutationSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return OutputSpec.create( + token=OutputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + log.debug("\n[serialize_signature]") + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], + ) + + def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: + if isinstance(x, ep.TensorArgument): + return Argument.create(as_tensor=TensorArgument(name=x.name)) + elif isinstance(x, ep.SymIntArgument): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) + elif isinstance(x, ep.SymFloatArgument): + return Argument.create(as_sym_float=SymFloatArgument.create(as_name=x.name)) + elif isinstance(x, ep.ConstantArgument): + return self.serialize_input(x.value) + elif isinstance(x, ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn) + ) + else: + raise AssertionError("TODO") + + def serialize_treespec(self, treespec: pytree.TreeSpec) -> str: + # We want to additionally save all the field names of the namedtuples in + # case users want to check that the treespec types are equivalent + def store_namedtuple_fields(ts: pytree.TreeSpec) -> None: + if ts.type is None: + return + if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): + serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ + ts.context + ].serialized_type_name + if serialized_type_name in self.treespec_namedtuple_fields: + field_names = 
def serialize_module_call_graph(
    self, module_call_graph: list[ep.ModuleCallEntry]
) -> list[ModuleCallEntry]:
    """Serialize every entry of a module call graph.

    Each entry keeps its ``fqn``; its signature is serialized with
    ``self.serialize_module_call_signature`` when it is truthy and stored as
    ``None`` otherwise.

    Args:
        module_call_graph: Entries to serialize, in order.

    Returns:
        One serialized ``ModuleCallEntry`` per input entry, in the same order.
    """
    log.debug("\n[serialize_module_call_graph]")
    serialized_entries = []
    for call_entry in module_call_graph:
        qualified_name = call_entry.fqn
        if call_entry.signature:
            signature = self.serialize_module_call_signature(call_entry.signature)
        else:
            signature = None
        serialized_entries.append(
            ModuleCallEntry(fqn=qualified_name, signature=signature)
        )
    return serialized_entries
+ element0 = call_function(getitem, x, 0) + foo = call_function("use_output", element0) + + We do not want the intermediate `getitem` call, so our serialized thing looks like: + + element0, element1, element2 = call_function("multiple_return", ...) + foo = call_function("use_output", element0) + + We want names to be consistent across these two schemes, so that we can + mostly reuse the names coming from FX. This function computes a mapping from + the FX representation to our representation, preserving the names. + """ + + def _is_single_tensor_list_return(target: Any) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + + if len(returns) != 1: + return False + return_type = returns[0].real_type + return isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ) + + assert node.op == "call_function" and isinstance( + node.target, (torch._ops.OpOverload, *_registered_extension_types()) + ) + + schema = _get_schema_from_target(node.target) + returns = schema.returns + + if len(returns) == 0: + return [] + + meta_val = node.meta["val"] + + # Check single value return + if _is_single_tensor_list_return(node.target): + # e.g "-> Tensor[]" + tensor_args = [] + for idx, meta in enumerate(meta_val): + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + elif len(returns) == 1: + return [self.serialize_output(node.name, meta_val)] + + # There are a two possibilities at this point: + # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)" + # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])" + # + # Either way, start by gathering a list of TensorArguments with the correct names. + # For consistent naming with FX, consult the downstream `getitem` node and + # make sure our outputs have the same name. 
+ + output_arguments = [] + for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): + if meta is None: + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) + # When the return type is annotated as Tensor type, the op can also return an + # undefined Tensor which will be implicitly converted to None in Python. + output_arguments.append(Argument.create(as_none=True)) + elif isinstance(meta, FakeTensor): + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) + name = self._output_node_name_at_index(node, idx) + output_arguments.append(self.serialize_output(name, meta)) + elif isinstance(meta, list): + # for List[Tensor] return type + assert isinstance( + return_schema.real_type, torch.ListType + ) and isinstance( + return_schema.real_type.getElementType(), torch.TensorType + ) + user_node = self._output_node_at_index(node, idx) + assert user_node is not None + + args = [] + for i, m in enumerate(meta): + if m is None: + continue + sub_user_node_name = self._output_node_name_at_index(user_node, i) + args.append(self.serialize_tensor_output(sub_user_node_name, m)) + output_arguments.append(Argument.create(as_tensors=args)) + elif isinstance(meta, (int, SymInt, float, SymFloat)): + user_node_name = self._output_node_name_at_index(node, idx) + output_arguments.append(self.serialize_output(user_node_name, meta)) + else: + raise ValueError( + f"Unhandled output type {type(meta)} from node {node.format_node()}" + ) + + return output_arguments + + def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]: + """ + For serializing HOO outputs since HOOs do not have a schema. 
+ """ + meta_val = node.meta["val"] + + if isinstance(meta_val, tuple): + outputs = [] + for i, element_meta_val in enumerate(meta_val): + user_node = self._output_node_at_index(node, i) + if isinstance(element_meta_val, list): + # e.g "-> Tensor[]" + assert user_node is not None + + tensors = [] + for j, m in enumerate(element_meta_val): + if not isinstance(m, torch.Tensor): + raise SerializeError( + f"Serialize list output with type {type(m)} nyi" + ) + + name = self._output_node_name_at_index(user_node, j) + tensors.append(self.serialize_tensor_output(name, m)) + outputs.append(Argument.create(as_tensors=tensors)) + + else: + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{i}" + ) + + outputs.append(self.serialize_output(name, element_meta_val)) + + return outputs + elif isinstance(meta_val, dict): + tensor_args = [] + # use the dict key as the idx + for idx, meta in meta_val.items(): + if not isinstance(meta, torch.Tensor): + raise SerializeError( + f"Serialize list output with type {type(meta)} nyi" + ) + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + else: + return [self.serialize_output(node.name, meta_val)] + + def serialize_output(self, name: str, meta_val: Any) -> Argument: + # Check single value return + if meta_val is None: + return Argument.create(as_none=True) + if isinstance(meta_val, torch.Tensor): + # e.g "-> Tensor" + return Argument.create( + as_tensor=self.serialize_tensor_output(name, meta_val) + ) + elif isinstance(meta_val, (bool, torch.SymBool)): + # e.g "-> SymBool" + return Argument.create( + as_sym_bool=self.serialize_sym_bool_output(name, meta_val) + ) + elif isinstance(meta_val, (int, torch.SymInt)): + # e.g "-> SymInt" + assert not isinstance(meta_val, bool) + return Argument.create( + as_sym_int=self.serialize_sym_int_output(name, meta_val) + ) + elif isinstance(meta_val, (float, 
torch.SymFloat)): + # e.g "-> SymFloat" + return Argument.create( + as_sym_float=self.serialize_sym_float_output(name, meta_val) + ) + + # list outputs should've been handled earlier + raise SerializeError(f"Unable to serialize output {meta_val}") + + def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]: + meta_val = node.meta["val"] + + idx_to_name = {} + for user in node.users: + assert user.target is operator.getitem, ( + f"User node {user} of {node} is incorrect" + ) + idx_to_name[user.args[1]] = user.name + + for idx, _ in enumerate(meta_val): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. + if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + arg_list = [] + for i, element_meta_val in enumerate(meta_val): + arg_list.append( + self.serialize_tensor_output(idx_to_name[i], element_meta_val) + ) + + return arg_list + + def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: + assert isinstance(graph_module, torch.fx.GraphModule) + log.debug( + "[serialize_graph]\n\n%s", graph_module.print_readable(print_output=False) + ) + + for node in graph_module.graph.nodes: + try: + getattr(self, f"handle_{node.op}")(node) + except Exception as e: + raise SerializeError( + f"Failed serializing node {node} in graph: {node.format_node()}\n Original exception {traceback.format_exc()}" + ) from e + + return Graph( + inputs=self.graph_state.inputs, + nodes=self.graph_state.nodes, + tensor_values=self.graph_state.tensor_values, + sym_int_values=self.graph_state.sym_int_values, + sym_float_values=self.graph_state.sym_float_values, + sym_bool_values=self.graph_state.sym_bool_values, + custom_obj_values=self.graph_state.custom_obj_values, + outputs=self.graph_state.outputs, + is_single_tensor_return=self.graph_state.is_single_tensor_return, + ) + + def 
@final
class ExportedProgramSerializer(metaclass=Final):
    """Serialize an ``ep.ExportedProgram`` into a ``_SerializedProgram``."""

    def __init__(
        self,
        opset_version: Optional[dict[str, int]] = None,
        pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
    ):
        """
        Args:
            opset_version: Optional opset versions keyed by namespace. The
                mapping is copied; an ``"aten"`` entry is added when missing.
            pickle_protocol: Protocol passed to ``serialize_torch_artifact``.
        """
        versions: dict[str, int] = {}
        if opset_version:
            versions.update(opset_version)
        if "aten" not in versions:
            versions["aten"] = torch._C._get_max_operator_version()

        self.opset_version: dict[str, int] = versions
        self.pickle_protocol = pickle_protocol

    def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
        """
        Args:
            exported_program: Exported Program to serialize
        """
        exported_program.validate()

        graph_serializer = GraphModuleSerializer(
            exported_program.graph_signature, exported_program.module_call_graph
        )
        graph_module_payload = graph_serializer.serialize(
            exported_program.graph_module
        )
        range_constraints_payload = serialize_range_constraints(
            exported_program.range_constraints
        )

        # TODO: Directly serialize exported_program.constants once
        # CustomClassHolders get stored in the ExportedProgram rather than in
        # the graph
        constants: dict[str, Any] = graph_serializer.custom_objs.copy()
        for const_name, const_value in exported_program.constants.items():
            assert const_name not in constants
            constants[const_name] = const_value

        schema_version = SchemaVersion(
            major=SCHEMA_VERSION[0],
            minor=SCHEMA_VERSION[1],
        )
        dialects = [verifier.dialect for verifier in exported_program.verifiers]
        serialized_program = ExportedProgram(
            graph_module=graph_module_payload,
            opset_version=self.opset_version,
            range_constraints=range_constraints_payload,
            schema_version=schema_version,
            verifiers=dialects,
            torch_version=torch.__version__,
            guards_code=exported_program._guards_code,
        )

        # Test canonical form is well defined.
        canonicalize(serialized_program, set(constants))

        # Proxy cannot be dumped, so we remove them.
        proxy_free_state_dict = remove_proxy_from_state_dict(
            exported_program.state_dict, in_place=False
        )

        protocol = self.pickle_protocol
        state_dict_bytes = serialize_torch_artifact(proxy_free_state_dict, protocol)
        constants_bytes = serialize_torch_artifact(constants, protocol)
        example_inputs_bytes = serialize_torch_artifact(
            exported_program.example_inputs, protocol
        )
        return _SerializedProgram(
            serialized_program,
            state_dict_bytes,
            constants_bytes,
            example_inputs_bytes,
        )
def deserialize_operator(self, serialized_target: str):
    """Resolve a serialized operator name to the object it refers to.

    Names starting with ``#`` are delegated to
    ``self.deserialize_extension_operator``. Names starting with
    ``_operator``, ``torch`` or ``math`` are resolved by walking the dotted
    attribute path (everything after the first component) from the
    ``operator``, ``torch`` or ``math`` module respectively.

    Returns:
        The resolved object, or ``serialized_target`` unchanged when the
        prefix is not recognised or any attribute on the path is missing.
    """
    if serialized_target.startswith("#"):
        return self.deserialize_extension_operator(serialized_target)

    # Checked in order; the first matching prefix decides the root module.
    # TODO(zhxchen17) Follow up on the "_operator" prefix.
    prefix_roots = (
        ("_operator", operator),
        ("torch", torch),
        ("math", math),
    )
    root = next(
        (
            module
            for prefix, module in prefix_roots
            if serialized_target.startswith(prefix)
        ),
        None,
    )
    if root is None:
        # TODO(zhxchen17) Don't catch all here.
        return serialized_target

    missing = object()
    resolved = root
    for attribute in serialized_target.split(".")[1:]:
        resolved = getattr(resolved, attribute, missing)
        if resolved is missing:
            return serialized_target
    return resolved
+ """ + + def _process_sym_expr( + sym: sympy.Expr, hint: Optional[Union[int, bool, float]] = None + ) -> sympy.Expr: + if sym.is_Integer or sym.is_Float or sym.is_Boolean: # base case + return sym + else: # recursive case + # important to use str(expr) and not _print_sympy(), + # str(expr) is key for self.symbol_name_to_range + expr_str = str(sym) + for arg in sym.args: + self._parse_sym_expr(arg) + # symbol caching + if expr_str in self.symbol_name_to_symbol: + sym = self.symbol_name_to_symbol[expr_str] + else: + self.symbol_name_to_symbol[expr_str] = sym + if isinstance(sym, sympy.Symbol) and symbolic_shapes.symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT) + ): + self.unbacked_symbols.add(sym) + # hints + if hint is not None and sym not in self.shape_env.var_to_val: + self.shape_env.add_var_to_val(sym, hint) # type: ignore[arg-type] + # ValueRanges + if vr := self.symbol_name_to_range.get(expr_str): + self.shape_env.constrain_symbol_range( + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + # ShapeEnv meta + if isinstance(sym, sympy.Symbol): + self.shape_env.var_to_stack[sym] = CapturedTraceback.extract(skip=1) + return sym + + expr = sympy.sympify( + expr_str, + locals={**self.sympy_functions, **self.symbol_name_to_symbol}, + ) + return _process_sym_expr(expr, hint) + + def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: + val = s.value + if s.type == "as_expr": + if val.hint is None: + hint = None + else: + assert val.hint.type == "as_int" + hint = val.hint.value + + sym = self._parse_sym_expr(val.expr_str, hint) + return self.shape_env.create_symintnode(sym, hint=hint) + elif s.type == "as_int": + assert type(val) is int + return val + else: + raise SerializeError( + f"SymInt has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_float(self, s: SymFloat) -> Union[float, torch.SymFloat]: + val = s.value + if s.type == "as_expr": + hint = 
def deserialize_graph_output(
    self, output
) -> Optional[Union[torch.fx.Node, int, float, bool]]:
    """Map one serialized graph output to the value placed in the FX output node.

    Tensor and symbolic outputs are looked up by name in
    ``self.serialized_name_to_node``; constant outputs are returned as their
    plain Python value.

    Args:
        output: Serialized output argument; ``output.type`` selects the branch.

    Returns:
        The FX node registered for a tensor/symbolic output, the constant
        ``int``/``float``/``bool`` for a constant output, or ``None`` for an
        ``as_none`` output.

    Raises:
        SerializeError: If ``output.type`` is not one of the handled kinds.
        KeyError: If a named output was never registered in
            ``self.serialized_name_to_node``.
    """
    if output.type == "as_tensor":
        return self.serialized_name_to_node[output.as_tensor.name]
    elif output.type == "as_sym_int":
        return self.serialized_name_to_node[output.as_sym_int.as_name]
    elif output.type == "as_sym_bool":
        return self.serialized_name_to_node[output.as_sym_bool.as_name]
    elif output.type == "as_sym_float":
        return self.serialized_name_to_node[output.as_sym_float.as_name]
    elif output.type == "as_int":
        return output.as_int
    elif output.type == "as_float":
        return output.as_float
    elif output.type == "as_bool":
        return output.as_bool
    elif output.type == "as_none":
        return None
    else:
        raise SerializeError(f"Unable to deserialize output node {output}")
some nn.Modules use "input" as forward() arguments) + # we will overwrite it + placeholder_node.name = node_name + self.sync_fx_node(node_name, placeholder_node) + elif input_.type == "as_sym_int": + if input_.value.type == "as_name": + node_name = input_.value.as_name + placeholder_node = self.graph.placeholder(node_name) + # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments) + # we will overwrite it + placeholder_node.name = node_name + self.sync_fx_node(node_name, placeholder_node) + else: + raise SerializeError( + f"Deserializing a constant symint {input_.value} as an input" + ) + elif input_.type in ( + "as_int", + "as_float", + "as_bool", + "as_none", + "as_string", + ): + node_name = self.signature.input_specs[i].arg.name or f"arg{i}" + placeholder_node = self.graph.placeholder(node_name) + placeholder_node.meta["val"] = self.deserialize_input(input_) + else: + raise SerializeError(f"Invalid input type {input_}") + + # Nodes: convert to call_function nodes. + for serialized_node in serialized_graph.nodes: + try: + target = self.deserialize_operator(serialized_node.target) + self.deserialize_node(serialized_node, target) + + except Exception as e: + raise SerializeError( + f"Failed deserializing node {serialized_node}\n Original exception {traceback.format_exc()}" + ) from e + + # Outputs: convert to a single `output` node. 
+ outputs = [] + for output in serialized_graph.outputs: + log.debug("[deserialize output] %s", output) + outputs.append(self.deserialize_graph_output(output)) + + if serialized_graph.is_single_tensor_return: + assert len(outputs) == 1 + outputs = outputs[0] # type: ignore[assignment] + else: + outputs = tuple(outputs) # type: ignore[assignment] + + output_node = self.graph.output(outputs) + + if serialized_graph.is_single_tensor_return: + output_node.meta["val"] = output_node.args[0].meta["val"] + else: + output_node.meta["val"] = tuple( + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ) + + # recompute unbacked bindings + for node in self.graph.nodes: + if (val := node.meta.get("val")) is not None and ( + unbacked_bindings := symbolic_shapes._free_unbacked_symbols_with_path( + val, + (), + shape_env=self.shape_env, + pending=self.unbacked_symbols, + simplify=True, + ) + ): + node.meta["unbacked_bindings"] = unbacked_bindings + + assert len(self.unbacked_symbols) == 0 + return self.graph + + def deserialize_node(self, serialized_node: Node, target: Callable) -> None: + def _is_single_tensor_return(target) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + return len(returns) == 1 and isinstance( + returns[0].real_type, torch.TensorType + ) + + if ( + target in _SYM_OPS + or target + == torch.ops.aten.item.default # this can produce either SymInt or SymBool + ): + name = serialized_node.outputs[0].value.as_name + args = self.deserialize_sym_op_inputs(serialized_node.inputs) + + fx_node = self.graph.create_node("call_function", target, args, {}, name) + self.deserialize_sym_op_outputs(serialized_node, fx_node) + elif ( + target + is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional + ): + raise SerializeError( + "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional" + ) + elif isinstance(target, 
torch._ops.HigherOrderOperator): + args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) + metadata = self.deserialize_metadata(serialized_node.metadata) + for x in (*args, *kwargs.values()): + if isinstance(x, torch.fx.Node) and x.op == "get_attr": + # this means that we have deserialized a graph argument, but + # unfortunately the schema for it does not include metadata; + # so we reuse the metadata of the HOP call for such arguments + x.meta.update(metadata) + # If a serialized HOP node has a length=1 outputs of type `as_tensor``. + # There could be two cases: + # (1) The HOP node returns a single tensor + # (2) The HOP node returns a tuple containing a single tensor + # We distinguish (1) and (2) by the `is_single_tensor_return` + # field in the schema of Node + # For BC, getattr() will return True if `is_single_tensor_return` doesn't + # exist. This is because prior to adding `is_single_tensor_return`, + # only (1) could happen as we handle (2) with type `as_tensors` + name = ( + serialized_node.outputs[0].as_tensor.name + if len(serialized_node.outputs) == 1 + and hasattr(serialized_node.outputs[0], "as_tensor") + and getattr(serialized_node, "is_hop_single_tensor_return", True) + else None + ) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + fx_node.meta.update(metadata) + + elif isinstance( + target, (torch._ops.OpOverload, *_registered_extension_types()) + ): + # For convenience: if this node returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + name = ( + serialized_node.outputs[0].as_tensor.name + if _is_single_tensor_return(target) + else None # FX will generate a name for us. 
+ ) + args, kwargs = self.deserialize_inputs(target, serialized_node) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + else: + _additional_msg = ( + ( + f"We failed to resolve {target} to an operator. " + + "If it's a custom op/custom triton op, this is usually because the custom op is not registered" + + " when deserializing. Please import the custom op to register it before deserializing." + + " Otherwise, please file an issue on github." + ) + if isinstance(target, str) + else "" + ) + raise SerializeError( + _additional_msg + + f" Unsupported target type for node {serialized_node}: {type(target)}." + ) + + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + log.debug( + "[deserialize_node] %s: %s(%s, {%s}) -> %s", + fx_node.name, + fx_node.target, + fx_node.args, + fx_node.kwargs, + fx_node.meta.get("val"), + ) + + # handle ShapeEnv asserts + if target is torch.ops.aten._assert_scalar.default: + if not isinstance((arg := fx_node.args[0]), bool): + expr = arg.meta["val"] # type: ignore[union-attr] + if isinstance(expr, torch.SymBool): + self.shape_env.guard_or_defer_runtime_assert( + expr.node.expr, "", fx_node + ) + elif target is torch.ops.aten.sym_constrain_range_for_size.default: + sym = fx_node.args[0].meta["val"] # type: ignore[union-attr] + if isinstance(sym, torch.SymInt): + self.shape_env._constrain_range_for_size(sym.node.expr) + + # handle nn_module_stack; serialization throws away empty dicts + if ( + fx_node.op not in ["placeholder", "output"] + and "nn_module_stack" not in fx_node.meta + ): + fx_node.meta["nn_module_stack"] = {} + + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + log.debug("[deserialize_input_spec] %s", i) + if i.type == "user_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None, + ) + elif i.type == "parameter": + return 
ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=ep.TensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.type == "buffer": + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=ep.TensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + persistent=i.buffer.persistent, + ) + elif i.type == "tensor_constant": + return ep.InputSpec( + kind=ep.InputKind.CONSTANT_TENSOR, + arg=ep.TensorArgument(name=i.tensor_constant.arg.name), + target=i.tensor_constant.tensor_constant_name, + ) + elif i.type == "custom_obj": + return ep.InputSpec( + kind=ep.InputKind.CUSTOM_OBJ, + arg=ep.CustomObjArgument( + name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn + ), + target=i.custom_obj.custom_obj_name, + ) + elif i.type == "token": + return ep.InputSpec( + kind=ep.InputKind.TOKEN, + arg=ep.TokenArgument(name=i.token.arg.name), + target=None, + ) + elif i.type == "constant_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=ep.ConstantArgument( + name=i.constant_input.name, + value=self.deserialize_constant_input(i.constant_input.value), + ), + target=None, + ) + else: + raise AssertionError(f"Unknown input spec {i}") + + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + log.debug("[deserialize_output_spec] %s", o) + if o.type == "user_output": + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.type == "loss_output": + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + arg=ep.TensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.type == "buffer_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name, + ) + elif o.type == "parameter_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.PARAMETER_MUTATION, + 
arg=ep.TensorArgument(name=o.parameter_mutation.arg.name), + target=o.parameter_mutation.parameter_name, + ) + elif o.type == "gradient_to_parameter": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name, + ) + elif o.type == "gradient_to_user_input": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name, + ) + elif o.type == "user_input_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.USER_INPUT_MUTATION, + arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), + target=o.user_input_mutation.user_input_name, + ) + elif o.type == "token": + return ep.OutputSpec( + kind=ep.OutputKind.TOKEN, + arg=ep.TokenArgument(name=o.token.arg.name), + target=None, + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: + log.debug("\n[deserialize_signature]") + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs], + ) + + def deserialize( + self, + serialized_graph_module: GraphModule, + serialized_state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, Any], bytes], + example_inputs: Optional[ + Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes] + ] = None, + symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Result: + global _CURRENT_DESERIALIZER + assert _CURRENT_DESERIALIZER is None + _CURRENT_DESERIALIZER = self + try: + log.debug("\n[deserialize]") + self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=True, + 
shape_env=self.shape_env, + ) + self.sympy_functions = { + # all torch.utils._sympy.functions should go here + # TODO(avik): find a better way to keep this collection in sync; + # e.g.., `exec('from torch.utils._sympy.functions import *', ...)` + # would work as long as the public API of that module is complete + "FloorDiv": torch.utils._sympy.functions.FloorDiv, + "ModularIndexing": torch.utils._sympy.functions.ModularIndexing, + "Where": torch.utils._sympy.functions.Where, + "PythonMod": torch.utils._sympy.functions.PythonMod, + "Mod": torch.utils._sympy.functions.Mod, + "CleanDiv": torch.utils._sympy.functions.CleanDiv, + "CeilToInt": torch.utils._sympy.functions.CeilToInt, + "FloorToInt": torch.utils._sympy.functions.FloorToInt, + "CeilDiv": torch.utils._sympy.functions.CeilDiv, + "LShift": torch.utils._sympy.functions.LShift, + "RShift": torch.utils._sympy.functions.RShift, + "PowByNatural": torch.utils._sympy.functions.PowByNatural, + "FloatPow": torch.utils._sympy.functions.FloatPow, + "FloatTrueDiv": torch.utils._sympy.functions.FloatTrueDiv, + "IntTrueDiv": torch.utils._sympy.functions.IntTrueDiv, + "IsNonOverlappingAndDenseIndicator": torch.utils._sympy.functions.IsNonOverlappingAndDenseIndicator, + "TruncToFloat": torch.utils._sympy.functions.TruncToFloat, + "TruncToInt": torch.utils._sympy.functions.TruncToInt, + "RoundToInt": torch.utils._sympy.functions.RoundToInt, + "RoundDecimal": torch.utils._sympy.functions.RoundDecimal, + "ToFloat": torch.utils._sympy.functions.ToFloat, + "Identity": torch.utils._sympy.functions.Identity, + } + self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {} + self.constants = deserialize_torch_artifact(constants) + self.signature = self.deserialize_signature( + serialized_graph_module.signature + ) + + # deserialization does analysis with checks on 0/1, so we create fake range constraints and + # restore the original range constraints afterwards + self.symbol_name_to_range = {} + # we also need to bump unbacked 
sym[float,int] counters in the + # shape env to accommodate unbacked symbols in the exported program + self.unbacked_symbols = set() + count_unbacked_symfloat, count_unbacked_symint = -1, -1 + unbacked_symfloat_prefix, unbacked_symint_prefix = ( + prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT] + ) + if symbol_name_to_range: + for k, vr in symbol_name_to_range.items(): + lower = vr.lower + self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( + _int_to_sympy_int(lower, -int_oo), vr.upper + ) + if k.startswith(unbacked_symfloat_prefix): + i = int(k[len(unbacked_symfloat_prefix) :]) + count_unbacked_symfloat = max(count_unbacked_symfloat, i) + elif k.startswith(unbacked_symint_prefix): + i = int(k[len(unbacked_symint_prefix) :]) + count_unbacked_symint = max(count_unbacked_symint, i) + + # TODO(pianpwk): if we can clean up unused symbols in range_constraints, + # then this logic can just be handled with self.unbacked_symbols alone + for _ in range(count_unbacked_symfloat + 1): + self.shape_env.unbacked_symfloat_counter += 1 + for _ in range(count_unbacked_symint + 1): + self.shape_env.unbacked_symint_counter += 1 + + if example_inputs is not None and len(example_inputs) > 0: + self.example_inputs = deserialize_torch_artifact(example_inputs) + else: + self.example_inputs = None + self.deserialize_graph(serialized_graph_module.graph) + + with _enable_graph_inputs_of_type_nn_module(self.example_inputs): + module_call_graph = self.deserialize_module_call_graph( + serialized_graph_module.module_call_graph + ) + graph_module = ep._create_graph_module_for_export(self.module, self.graph) + meta = {} + if custom := serialized_graph_module.metadata.get("custom"): + meta["custom"] = json.loads(custom) + if hasattr(serialized_graph_module, "treespec_namedtuple_fields"): + meta["treespec_namedtuple_fields"] = {} + for ( + type_, + fields, + ) in serialized_graph_module.treespec_namedtuple_fields.items(): + meta["treespec_namedtuple_fields"][type_] = 
fields.field_names + graph_module.meta = meta + return GraphModuleDeserializer.Result( + graph_module=graph_module, + signature=self.signature, + module_call_graph=module_call_graph, + names_to_symbols=self.symbol_name_to_symbol, + state_dict=deserialize_torch_artifact(serialized_state_dict), + constants=self.constants, + example_inputs=self.example_inputs, + ) + finally: + _CURRENT_DESERIALIZER = None + + def sync_fx_node(self, name: str, fx_node: torch.fx.Node): + if name in self.serialized_name_to_node: + raise SerializeError(f"Node {name} has already been deserialized before.") + # overwrite name + fx_node.name = name + self.serialized_name_to_node[name] = fx_node + assert "val" not in fx_node.meta + fx_node.meta["val"] = self.serialized_name_to_meta[name] + + def deserialize_sym_op_inputs(self, inputs): + return tuple(self.deserialize_input(input.arg) for input in inputs) + + def deserialize_inputs(self, target, serialized_node: Node): + schema_args = _get_schema_from_target(target).arguments + argument_kinds = {input.name: input.kind for input in serialized_node.inputs} + actual_args = { + input.name: self.deserialize_input(input.arg) + for input in serialized_node.inputs + } + args = [] + kwargs: OrderedDict[str, Any] = OrderedDict() + for schema_arg in schema_args: + if schema_arg.name in actual_args: + arg = actual_args[schema_arg.name] + kind = argument_kinds[schema_arg.name] + if kind == ArgumentKind.POSITIONAL: + args.append(arg) + continue + elif kind == ArgumentKind.KEYWORD and not keyword.iskeyword( + schema_arg.name + ): + kwargs[schema_arg.name] = arg + continue + + # If there's no ArgumentKind found, fallback to the old cases. 
+ is_positional = ( + not schema_arg.has_default_value() and not schema_arg.kwarg_only + ) + if is_positional: + args.append(actual_args[schema_arg.name]) + elif keyword.iskeyword(schema_arg.name): + assert not schema_arg.kwarg_only + if len(kwargs) > 0: + kwargs = OrderedDict() + args.extend(list(kwargs.values())) + args.append(actual_args[schema_arg.name]) + else: + if schema_arg.name in actual_args: + kwargs[schema_arg.name] = actual_args[schema_arg.name] + return tuple(args), kwargs + + def deserialize_hoo_inputs(self, inputs: list[NamedArgument]): + """ + For deserializing HOO inputs since HOOs do not have a schema. + """ + args = [] + kwargs = {} + for input_ in inputs: + if input_.name != "": + kwargs[input_.name] = self.deserialize_input(input_.arg) + else: + args.append(self.deserialize_input(input_.arg)) + return (tuple(args), kwargs) + + def deserialize_input(self, inp: Argument) -> Any: + value = inp.value + typ_ = inp.type + if typ_ == "as_none": + # None should converted as None, but is encoded as bool in serialized + # Convert serialized object to torch equivalent + return None + elif typ_ == "as_tensor": + return self.serialized_name_to_node[inp.as_tensor.name] + elif typ_ == "as_scalar_type": + return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] + elif typ_ == "as_memory_format": + return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] + elif typ_ == "as_layout": + return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = ep._create_graph_module_for_export(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + name=value.name, + ) + elif typ_ == "as_device": + return deserialize_device(inp.as_device) + elif typ_ == "as_int": + return inp.as_int + elif typ_ == "as_float": + return inp.as_float + elif typ_ == 
"as_bool": + return inp.as_bool + elif typ_ == "as_string": + return inp.as_string + elif typ_ == "as_complex": + return complex(inp.as_complex.real, inp.as_complex.imag) + elif typ_ == "as_sym_int": + return self.deserialize_sym_argument(inp.as_sym_int) + elif typ_ == "as_sym_float": + return self.deserialize_sym_argument(inp.as_sym_float) + elif typ_ == "as_sym_bool": + return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, dict): + if typ_ == "as_string_to_argument": + # Deserialize dict[str, Argument] recursively + return {k: self.deserialize_input(v) for k, v in value.items()} + else: + raise SerializeError(f"Unknown dict type: {typ_}") + elif isinstance(value, list): + if len(value) == 0: + return [] + elif typ_ == "as_tensors": + result = [self.serialized_name_to_node[arg.name] for arg in value] + return result + elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): + # convert from serialized.python.types.List to python list + return list(value) + elif typ_ == "as_int_lists": + # Convert list of lists back to list of tuples for Triton grids + return [tuple(dims) for dims in value] + elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"): + return [self.deserialize_sym_argument(arg) for arg in value] + elif typ_ == "as_optional_tensors": + + def deserialize_optional_tensor_args(a): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return self.serialized_name_to_node[a.value.name] + else: + raise SerializeError(f"Unhandled argument {inp}") + + return list(map(deserialize_optional_tensor_args, value)) + else: + raise SerializeError(f"Unhandled argument {inp}") + elif typ_ == "as_custom_obj": + if inp.as_custom_obj.name in self.serialized_name_to_node: + # Custom object has been lifted as an input + return self.serialized_name_to_node[inp.as_custom_obj.name] + return self.constants[inp.as_custom_obj.name] + elif typ_ == "as_operator": + return self.deserialize_operator(inp.as_operator) + else: + 
raise SerializeError(f"Unhandled argument {inp}") + + def deserialize_constant_input(self, inp: ConstantValue) -> Any: + if inp.type == "as_int": + return int(inp.as_int) + elif inp.type == "as_float": + return float(inp.as_float) + elif inp.type == "as_string": + return str(inp.as_string) + elif inp.type == "as_bool": + return bool(inp.as_bool) + elif inp.type == "as_none": + return None + else: + raise SerializeError(f"Unhandled constant argument {inp} to deserialize") + + def deserialize_sym_argument(self, sym_arg): + if isinstance(sym_arg, SymIntArgument): + if sym_arg.type == "as_int": + return sym_arg.as_int + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymFloatArgument): + if sym_arg.type == "as_float": + return sym_arg.as_float + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymBoolArgument): + if sym_arg.type == "as_bool": + return sym_arg.as_bool + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") + + def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + + def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + # Check single value return + if len(serialized_node.outputs) == 0: + return + + if ( + len(serialized_node.outputs) == 1 + and "torch.ops.higher_order" in serialized_node.target + and not getattr(serialized_node, "is_hop_single_tensor_return", True) + and serialized_node.outputs[0].type != "as_none" + ): + + def _deserialize_hop_with_single_return(serialized_node, fx_node): + meta_val: list[Any] = [] + arg = None + if serialized_node.outputs[0].type == "as_tensor": + arg = serialized_node.outputs[0].as_tensor + elif isinstance( + serialized_node.outputs[0].value, + (SymIntArgument, 
SymBoolArgument, SymFloatArgument), + ): + arg = serialized_node.outputs[0].value + deserialized_metadata = self.deserialize_metadata( + serialized_node.metadata + ) + assert arg is not None + # pyrefly: ignore [bad-argument-type] + self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + return + + return _deserialize_hop_with_single_return(serialized_node, fx_node) + + if ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_tensor" + ): + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return + elif len(serialized_node.outputs) == 1 and isinstance( + serialized_node.outputs[0].value, + (SymIntArgument, SymBoolArgument, SymFloatArgument), + ): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return + elif ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_none" + ): + # manually rename the node to a unused name to avoid naming conflicts + fx_node.meta["val"] = None + fx_node._rename(f"{self.graph._target_to_str(fx_node.target)}_unused") + return + + self.deserialize_multiple_outputs(serialized_node, fx_node) + + def generate_getitem( + self, + meta_val, + fx_node: torch.fx.Node, + arg: Union[TensorArgument, SymIntArgument, SymFloatArgument], + idx: int, + deserialized_metadata: dict[str, Any], + ): + if isinstance(arg, TensorArgument): + name = arg.name + elif isinstance(arg, SymIntArgument): + name = arg.as_name + elif isinstance(arg, SymFloatArgument): + name = arg.as_name + else: + raise AssertionError( + f"generate_getitem got unknown argument type {type(arg)}" + ) + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name=name, + ) + self.sync_fx_node(name, individual_output) + meta_val.append(self.serialized_name_to_meta[name]) + # The derived `getitem` nodes should have the same stacktrace 
as the + # original `fx_node` + individual_output.meta.update(deserialized_metadata) + + def generate_getitems( + self, + meta_val, + fx_node: torch.fx.Node, + args, + deserialized_metadata: dict[str, Any], + ): + for idx, arg in enumerate(args): + if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)): + self.generate_getitem( + meta_val, fx_node, arg, idx, deserialized_metadata + ) + continue + + assert isinstance(arg, Argument) + if arg.type in ("as_tensor", "as_sym_int", "as_sym_float"): + self.generate_getitem( + meta_val, fx_node, arg.value, idx, deserialized_metadata + ) + elif arg.type in ( + "as_tensors", + "as_sym_ints", + "as_sym_floats", + "as_ints", + "as_floats", + "as_strings", + "as_bools", + "as_sym_bools", + ): + list_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + ) + meta_val.append([]) + self.generate_getitems( + meta_val[-1], list_output, arg.value, deserialized_metadata + ) + list_output.meta.update(deserialized_metadata) + list_output.meta["val"] = meta_val[-1] + elif arg.type == "as_none": + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name="as_none", + ) + meta_val.append(None) + individual_output.meta["val"] = None + individual_output.meta.update(deserialized_metadata) + else: + raise NotImplementedError(f"Unimplemented node output type: {arg}") + + def deserialize_multiple_outputs( + self, serialized_node: Node, fx_node: torch.fx.Node + ) -> None: + deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) + + # Convert multiple return types to FX format. + # In FX, each node only returns one value. So in order to represent + # multiple return values, we have to emit a `getitem` node for each + # return value. 
+ # This performs the inverse mapping of the `serialize_outputs` call in + # serialization, see [NOTE: Multiple outputs] + meta_val: list[Any] = [] + if len(serialized_node.outputs) == 1: + assert isinstance(serialized_node.outputs[0].value, list) + assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + self.generate_getitems( + meta_val, + fx_node, + serialized_node.outputs[0].as_tensors, + deserialized_metadata, + ) + else: + self.generate_getitems( + meta_val, fx_node, serialized_node.outputs, deserialized_metadata + ) + + # also update the metaval for `fx_node` to be a list(meta) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + + def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + ret: dict[str, Any] = {} + if stack_trace := metadata.get("stack_trace"): + ret["stack_trace"] = stack_trace + + def deserialize_meta_func(serialized_target: str): + module = None + if serialized_target.startswith("torch.nn"): + module = torch.nn + serialized_target_names = serialized_target.split(".")[2:] + elif serialized_target.startswith("torch"): + module = torch + serialized_target_names = serialized_target.split(".")[1:] + else: + return self.deserialize_operator(serialized_target) + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + if nn_module_stack_str := metadata.get("nn_module_stack"): + # Originally serialized to "key,orig_path,type_str" + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + + # Helper function to split string by commas, accounting for nested parentheses/brackets + def metadata_split(metadata): + out = [] + start, n = 0, 0 + a, b = "[(", ")]" + for end, c in enumerate(metadata): + if c in a: + n += 1 + elif c in b: + n -= 1 + elif c == "," and n == 0: + out.append(metadata[start:end]) + start = end + 1 + 
out.append(metadata[start:]) + assert len(out) == 3 + return out + + nn_module_stack = dict( + import_nn_module_stack(*metadata_split(item)) + for item in nn_module_stack_str.split(ST_DELIMITER) + ) + ret["nn_module_stack"] = nn_module_stack + + if source_fn_st_str := metadata.get("source_fn_stack"): + # Originally serializes to "fx_node_name,op_str" + source_fn_st = [] + for source_fn_str in source_fn_st_str.split(ST_DELIMITER): + name, target_str = source_fn_str.split(",") + source_fn_st.append((name, deserialize_meta_func(target_str))) + ret["source_fn_stack"] = source_fn_st + + if torch_fn_str := metadata.get("torch_fn"): + ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER)) + + if custom_str := metadata.get("custom"): + ret["custom"] = json.loads(custom_str) + + return ret + + def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: + log.debug("[deserialize_argument_spec] %s", x) + if x.type == "as_tensor": + return ep.TensorArgument(name=x.as_tensor.name) + elif x.type == "as_sym_int": + return ep.SymIntArgument(name=x.as_sym_int.as_name) + elif x.type == "as_sym_float": + return ep.SymFloatArgument(name=x.as_sym_float.as_name) + elif x.type == "as_custom_obj": + return ep.ConstantArgument( + name=x.as_custom_obj.name, value=self.deserialize_input(x) + ) + else: + return ep.ConstantArgument(name="", value=self.deserialize_input(x)) + + def deserialize_module_call_signature( + self, module_call_signature: ModuleCallSignature + ) -> ep.ModuleCallSignature: + return ep.ModuleCallSignature( + inputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.outputs + ], + in_spec=treespec_loads(module_call_signature.in_spec), + out_spec=treespec_loads(module_call_signature.out_spec), + forward_arg_names=names + if (names := module_call_signature.forward_arg_names) + else None, + ) + + def deserialize_module_call_graph( + self, module_call_graph: 
list[ModuleCallEntry] + ) -> list[ep.ModuleCallEntry]: + log.debug("\n[deserialize_module_call_graph]") + return [ + ep.ModuleCallEntry( + fqn=entry.fqn, + signature=( + self.deserialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph + ] + + +@final +class ExportedProgramDeserializer(metaclass=Final): + def __init__(self, expected_opset_version: Optional[dict[str, int]] = None): + self.expected_opset_version: dict[str, int] = {} + if expected_opset_version: + self.expected_opset_version.update(expected_opset_version) + if "aten" not in self.expected_opset_version: + self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + + def deserialize_range_constraints( + self, + symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: dict[str, sympy.Symbol], + ) -> dict[sympy.Symbol, ValueRanges]: + log.debug("\n[deserialize_range_constraints]") + range_constraints = {} + for k, v in symbol_name_to_range.items(): + if symbol := symbol_name_to_symbol.get(k): + log.debug("[deserialize_range_constraints] %s -> %s", k, v) + range_constraints[symbol] = v # type: ignore[arg-type] + else: + log.warning( + "Symbol %s did not appear in the graph that was deserialized", k + ) + return range_constraints + + def deserialize( + self, + exported_program: ExportedProgram, + state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, torch.Tensor], bytes], + example_inputs: Optional[ + Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes] + ] = None, + *, + _unsafe_skip_version_check=False, + ) -> ep.ExportedProgram: + assert isinstance(exported_program, ExportedProgram) + version = exported_program.schema_version + + # TODO(zhxchen17) blocked on thrift schema refactor + if version.major != SCHEMA_VERSION[0] and not ( + version.major == 0 and version.minor == 0 + ): + if not _unsafe_skip_version_check: + raise SerializeError( + f"Serialized schema 
version {exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." + ) + + symbol_name_to_range = { + k: symbolic_shapes.ValueRanges( + _int_to_sympy_int(v.min_val, -int_oo), + _int_to_sympy_int(v.max_val, int_oo), + ) + for k, v in exported_program.range_constraints.items() + } + res = GraphModuleDeserializer().deserialize( + exported_program.graph_module, + state_dict, + constants, + example_inputs, + symbol_name_to_range, + ) + range_constraints = self.deserialize_range_constraints( + symbol_name_to_range, + res.names_to_symbols, + ) + + result = ep.ExportedProgram( + root=res.graph_module, + graph=res.graph_module.graph, + graph_signature=res.signature, + state_dict=res.state_dict, # type: ignore[arg-type] + range_constraints=range_constraints, + module_call_graph=res.module_call_graph, + example_inputs=res.example_inputs, + constants=res.constants, + verifiers=[load_verifier(v) for v in exported_program.verifiers], + ) + result._guards_code = exported_program.guards_code + log.debug("\n[deserialize]: %s", result) + return result + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + return super().default(obj) + + +def _dataclass_to_dict(obj): + if isinstance(obj, _Union): + return {obj.type: _dataclass_to_dict(obj.value)} + elif dataclasses.is_dataclass(obj): + return { + f.name: _dataclass_to_dict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + } + elif isinstance(obj, list): + return [_dataclass_to_dict(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_dataclass_to_dict(x) for x in obj) + elif isinstance(obj, dict): + return {k: _dataclass_to_dict(v) for k, v in obj.items()} + elif isinstance(obj, float): + if obj == math.inf: + return "Infinity" + elif obj == -math.inf: + return "-Infinity" + elif math.isnan(obj): + return "NaN" + else: + return obj + 
else: + return obj + + +def _to_json_bytes(obj: Any) -> bytes: + return json.dumps(_dataclass_to_dict(obj), cls=EnumEncoder, allow_nan=False).encode( + "utf-8" + ) + + +def serialize( + exported_program: ep.ExportedProgram, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> SerializedArtifact: + with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs): + serialized_program = ExportedProgramSerializer( + opset_version, pickle_protocol + ).serialize(exported_program) + assert isinstance(serialized_program.exported_program, ExportedProgram) + + json_bytes = _to_json_bytes(serialized_program.exported_program) + artifact = SerializedArtifact( + json_bytes, + serialized_program.state_dict, + serialized_program.constants, + serialized_program.example_inputs, + ) + return artifact + + +def _resolve_schema_cls(cls): + if isinstance(cls, str): + resolved = getattr(schema, cls, None) + if resolved is not None: + return resolved + if isinstance(cls, typing.ForwardRef): + return _resolve_schema_cls(cls.__forward_arg__) + return cls + + +def _dict_to_dataclass(cls, data): + cls = _resolve_schema_cls(cls) + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." 
+ if typing.get_origin(cls) is Annotated: + return _dict_to_dataclass(cls.__origin__, data) + if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): + if data is None: + return None + ty_args = typing.get_args(cls) + assert len(ty_args) == 2 + return _dict_to_dataclass(ty_args[0], data) + elif isinstance(cls, type) and issubclass(cls, _Union): + assert isinstance(data, dict) + assert len(data) == 1 + _type = next(iter(data.keys())) + _value = next(iter(data.values())) + assert isinstance(_type, str) + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) + field_type = type_hints[_type] + # pyrefly: ignore [missing-attribute] + return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) + elif dataclasses.is_dataclass(cls): + fields = {} + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) + # For forward compatibility consideration, we ignore all the keys + # that are not showing up in the dataclass definition. + for f in dataclasses.fields(cls): + name = f.name + if name not in data: + continue + new_field_obj = _dict_to_dataclass(type_hints[name], data[name]) + fields[name] = new_field_obj + return cls(**fields) # type: ignore[operator] + elif isinstance(data, list): + if len(data) == 0: + return data + d_type = typing.get_args(cls)[0] + return [_dict_to_dataclass(d_type, d) for d in data] + elif isinstance(data, dict): + v_type = typing.get_args(cls)[1] + return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()} + elif cls is float: + return float(data) + return data + + +def _bytes_to_dataclass(cls: Any, artifact_bytes: bytes) -> Any: + artifact_str = artifact_bytes.decode("utf-8") + artifact_dict = json.loads(artifact_str) + artifact_dataclass = _dict_to_dataclass(cls, artifact_dict) + return artifact_dataclass + + +def deserialize( + artifact: SerializedArtifact, + expected_opset_version: Optional[dict[str, int]] = None, + *, + _unsafe_skip_version_check=False, +) -> ep.ExportedProgram: + 
assert isinstance(artifact.exported_program, bytes) + serialized_exported_program = _bytes_to_dataclass( + ExportedProgram, artifact.exported_program + ) + return ExportedProgramDeserializer(expected_opset_version).deserialize( + serialized_exported_program, + artifact.state_dict, + artifact.constants, + artifact.example_inputs, + _unsafe_skip_version_check=_unsafe_skip_version_check, + ) + + +def _canonicalize_graph( + sorted_inputs, sorted_outputs, graph, constants +) -> tuple[Graph, dict[str, str]]: + def _get_argument(a: Argument): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return a.as_tensor + elif a.type == "as_tensors": + return a.as_tensors + elif a.type == "as_int": + return None + elif a.type == "as_ints": + return None + elif a.type == "as_float": + return None + elif a.type == "as_floats": + return None + elif a.type == "as_string": + return None + elif a.type == "as_strings": + return None + elif a.type == "as_complex": + return None + elif a.type == "as_sym_int": + return a.as_sym_int + elif a.type == "as_sym_ints": + return a.as_sym_ints + elif a.type == "as_sym_float": + return a.as_sym_float + elif a.type == "as_sym_floats": + return a.as_sym_floats + elif a.type == "as_scalar_type": + return None + elif a.type == "as_memory_format": + return None + elif a.type == "as_layout": + return None + elif a.type == "as_device": + return None + elif a.type == "as_bool": + return None + elif a.type == "as_bools": + return None + elif a.type == "as_sym_bool": + return a.as_sym_bool + elif a.type == "as_sym_bools": + return a.as_sym_bools + elif a.type == "as_graph": + return None + elif a.type == "as_optional_tensors": + return a.as_optional_tensors + elif a.type == "as_custom_obj": + return a.as_custom_obj + elif a.type == "as_operator": + return None + elif a.type == "as_int_lists": + return None + elif a.type == "as_string_to_argument": + return None + else: + raise AssertionError(f"Unknown input type to the ExportedProgram: 
{a}") + + # Stage 1: Reorder named items. + def for_args(f, a): + assert isinstance(a, Argument) + pytree.tree_map(f, _get_argument(a)) + + def sort_nodes(nodes): + @dataclass + class Edges: + outs: list[int] + ins: int + + graph_inputs: set[str] = set() + def_table: dict[str, int] = {} + edges: dict[int, Edges] = {} + candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = [] + rank: dict[str, int] = {} + ret: list[Node] = [] + + def get_name(a) -> Optional[str]: + if a is None: + return None + if isinstance(a, TensorArgument): + return a.name + elif isinstance(a, (SymIntArgument, SymBoolArgument, SymFloatArgument)): + if a.type == "as_name": + return a.as_name + elif a.type in ("as_int", "as_bool", "as_float"): + return None + else: + raise AssertionError(f"Unknown argument type: {a}") + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + return a.as_tensor.name + elif a.type == "as_none": + return None + else: + raise AssertionError(f"Unknown optional tensor type: {a}") + elif isinstance(a, CustomObjArgument): + return a.name + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + + def add_input(a): + if s := get_name(a): + graph_inputs.add(s) + + for_args(add_input, i) + + for idx, node in enumerate(nodes): + + def add_def(a): + if s := get_name(a): + assert s not in def_table + def_table[s] = idx + + for o in node.outputs: + for_args(add_def, o) + + edges[idx] = Edges([], 0) + + for idx, user in enumerate(nodes): + + def add_edge(a): + if s := get_name(a): + if s in constants: + return + if s not in def_table: + assert s in graph_inputs + return + src = def_table[s] + edges[src].outs.append(idx) + edges[idx].ins += 1 + + for i in user.inputs: + for_args(add_edge, i.arg) + + def add_rank(a): + if s := get_name(a): + assert s not in rank + rank[s] = len(rank) + + def get_rank(a): + s = get_name(a) + if s and s not in constants: + return rank[s] + else: + return -1 + + for i in sorted_inputs: + 
for_args(add_rank, i) + + def add_candidate(idx: int): + def get_ranks(i): + ranks = [] + for_args(lambda x: ranks.append(get_rank(x)), i) + return ranks + + node = nodes[idx] + args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] + heapq.heappush(candidates, (node.target, args_rank, idx)) + + for idx, e in edges.items(): + if e.ins == 0: + add_candidate(idx) + + while len(candidates) > 0: + _, _, idx = heapq.heappop(candidates) + node = nodes[idx] + for o in node.outputs: + for_args(add_rank, o) + ret.append(node) + assert idx in edges + for user in edges[idx].outs: + e = edges[user] + assert e.ins > 0 + e.ins -= 1 + if e.ins == 0: + add_candidate(user) + edges[idx].outs.clear() + + return ret + + sorted_nodes = sort_nodes(graph.nodes) + assert len(sorted_nodes) == len(graph.nodes) + + # Stage 2: Rename nodes. + name_table: dict[str, str] = {} + + def rename_def(a): + def _rename(arg_name, values): + new_name = f"_{len(name_table)}" + assert arg_name not in name_table + name_table[arg_name] = new_name + assert arg_name in values + values[new_name] = values.pop(arg_name) + return new_name + + if a is None: + return + if isinstance(a, TensorArgument): + a.name = _rename(a.name, graph.tensor_values) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_int_values) + elif isinstance(a, SymFloatArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_float_values) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_bool_values) + elif isinstance(a, CustomObjArgument): + a.name = _rename(a.name, graph.custom_obj_values) + else: + raise AssertionError(f"Unknown argument type: {a}") + + def replace_use(a): + if a is None: + return + if isinstance(a, TensorArgument): + a.name = name_table.get(a.name, a.name) + elif isinstance(a, (SymIntArgument, SymFloatArgument)): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, 
a.as_name) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name) + elif isinstance(a, CustomObjArgument): + a.name = name_table.get(a.name, a.name) + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + for_args(rename_def, i) + + for n in sorted_nodes: + for o in n.outputs: + for_args(rename_def, o) + + for n in sorted_nodes: + for i in n.inputs: + for_args(replace_use, i.arg) + + for o in sorted_outputs: + for_args(replace_use, o) + + # Stage 3: Remove unstable fields. + for n in sorted_nodes: + n.metadata.clear() + + # Stage 4: Aggregate values. + # pyrefly: ignore [no-matching-overload] + sorted_tensor_values = dict( + sorted(graph.tensor_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_sym_int_values = dict( + sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_sym_float_values = dict( + sorted(graph.sym_float_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_sym_bool_values = dict( + sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_custom_obj_values = dict( + sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0)) + ) + + # Stage 5: Recurse in subgraphs. 
+ counter = 0 + for node in sorted_nodes: + for i in node.inputs: + a = i.arg + if a.type == "as_graph": + a.as_graph.graph, _ = _canonicalize_graph( + a.as_graph.graph.inputs, + a.as_graph.graph.outputs, + a.as_graph.graph, + constants, + ) + a.as_graph.name = f"_g{counter}" + counter += 1 + + graph = Graph( + inputs=sorted_inputs, + outputs=sorted_outputs, + nodes=sorted_nodes, + tensor_values=sorted_tensor_values, + sym_int_values=sorted_sym_int_values, + sym_float_values=sorted_sym_float_values, + sym_bool_values=sorted_sym_bool_values, + is_single_tensor_return=graph.is_single_tensor_return, + custom_obj_values=sorted_custom_obj_values, + ) + return graph, name_table + + +def canonicalize( + ep: ExportedProgram, constants: Optional[set[str]] = None +) -> ExportedProgram: + """ + Normalize a serialized ExportedProgram, so that different eager program which + shares the same semantics can get a single representation on disk. + + This function canonicalizes an ExportedProgram by: + + 1. Sorting nodes in topological order. + 2. Rename nodes to have unique names. + 3. Remove unstable fields. + 4. Aggregate the above program fields. + 5. Recurse in subgraphs. + + Args: + ep (ExportedProgram): The ExportedProgram to canonicalize. + constants (Optional[set[str]]): Set of constants names + + Returns: + ExportedProgram: The canonicalized exported program. 
+ """ + ep = copy.deepcopy(ep) + # pyrefly: ignore [annotation-mismatch] + constants: set[str] = constants or set() + + opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) + range_constraints = dict( + sorted(ep.range_constraints.items(), key=operator.itemgetter(0)) + ) + guards_code = sorted(ep.guards_code) + module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) + signature = ep.graph_module.signature + graph = ep.graph_module.graph + + assert len(graph.inputs) == len(signature.input_specs) + assert len(graph.outputs) == len(signature.output_specs) + + def rank_input(inp) -> tuple[int, Optional[str], int]: + idx, (_arg, spec) = inp + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + return 5, None, idx + elif spec.type == "parameter": + return 1, spec.parameter.parameter_name, idx + elif spec.type == "buffer": + return 2, spec.buffer.buffer_name, idx + elif spec.type == "tensor_constant": + return 3, spec.tensor_constant.tensor_constant_name, idx + elif spec.type == "custom_obj": + return 4, spec.custom_obj.custom_obj_name, idx + elif spec.type == "token": + return 0, None, idx + elif spec.type == "constant_input": + return 6, spec.constant_input.name, idx + else: + raise AssertionError(f"Unknown input type: {spec}") + + def rank_output(out) -> tuple[int, Optional[str], int]: + idx, (_arg, spec) = out + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + return 4, None, idx + elif spec.type == "loss_output": + return 4, None, idx + elif spec.type == "parameter_mutation": + return 1, spec.parameter_mutation.parameter_name, idx + elif spec.type == "buffer_mutation": + return 2, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 5, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 6, None, idx + elif spec.type == "user_input_mutation": + return 3, None, idx + elif spec.type == 
"token": + return 0, None, idx + else: + raise AssertionError(f"Unknown output type: {spec}") + + sorted_ins = sorted( + enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input + ) + + if len(sorted_ins) > 0: + sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] + else: + sorted_inputs = () + input_specs = () + + sorted_outs = sorted( + enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output + ) + sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] + + sorted_graph, replace_table = _canonicalize_graph( + sorted_inputs, sorted_outputs, graph, constants + ) + + def replace_input(spec): + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + arg = spec.user_input.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type == "as_sym_float": + f = arg.as_sym_float + if f.type == "as_name": + f.as_name = replace_table[f.as_name] + elif f.type == "as_float": + pass + else: + raise AssertionError(f"Unknown sym_float type: {f}") + elif arg.type in ( + "as_none", + "as_bool", + "as_int", + "as_float", + "as_string", + "as_custom_obj", + ): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "parameter": + t = spec.parameter.arg + t.name = replace_table[t.name] + elif spec.type == "buffer": + t = spec.buffer.arg + t.name = replace_table[t.name] + elif spec.type == "tensor_constant": + t = spec.tensor_constant.arg + t.name = replace_table[t.name] + elif spec.type == "custom_obj": + t_custom_obj = spec.custom_obj.arg + t_custom_obj.name = replace_table[t_custom_obj.name] + return + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] 
+ elif spec.type == "constant_input": + return + else: + raise AssertionError(f"Unknown input type: {spec}") + + def replace_output(out): + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + arg = spec.user_output.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type == "as_sym_float": + f = arg.as_sym_float + if f.type == "as_name": + f.as_name = replace_table[f.as_name] + elif f.type == "as_float": + pass + else: + raise AssertionError(f"Unknown sym_float type: {f}") + elif arg.type in ("as_none", "as_bool", "as_int", "as_float", "as_string"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "loss_output": + t = spec.loss_output.arg + t.name = replace_table[t.name] + elif spec.type == "buffer_mutation": + t = spec.buffer_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "parameter_mutation": + t = spec.parameter_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_parameter": + t = spec.gradient_to_parameter.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_user_input": + g = spec.gradient_to_user_input + g.arg.name = replace_table[g.arg.name] + g.user_input_name = replace_table[g.user_input_name] + elif spec.type == "user_input_mutation": + u = spec.user_input_mutation + u.arg.name = replace_table[u.arg.name] + u.user_input_name = replace_table[u.user_input_name] + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + else: + raise AssertionError(f"Unknown output type: {spec}") + + for spec in input_specs: + replace_input(spec) + + for spec in output_specs: + replace_output(spec) + + return ExportedProgram( + graph_module=GraphModule( + 
graph=sorted_graph, + signature=GraphSignature( + input_specs=list(input_specs), + output_specs=list(output_specs), + ), + module_call_graph=module_call_graph, + ), + opset_version=opset_version, + range_constraints=range_constraints, + schema_version=ep.schema_version, + verifiers=ep.verifiers, + torch_version=ep.torch_version, + guards_code=guards_code, + ) + + +class ExtensionHandler: + """ + Base class for handling extension operators. + """ + + @classmethod + def namespace(cls) -> str: + raise NotImplementedError(f"{cls.__class__} namespace() must be implemented") + + @classmethod + def to_op_name(cls, op) -> str: + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def from_op_name(cls, name: str): + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def op_schema(cls, op) -> torch.FunctionSchema: + raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented") + + +def register_extension( + op_type: type[Any], + extension_handler: type[ExtensionHandler], +): + """Register custom de/serialization method for a node with non-standard type.""" + assert issubclass(extension_handler, ExtensionHandler), ( + f"Expected ExtensionHandler, got {extension_handler}." + ) + assert op_type not in _serialization_registry, f"{op_type} is already registered." + assert isinstance(op_type, type) # Maybe a good idea to enforce this first. + assert not ( + op_type.__module__.startswith("torch") + or op_type.__module__.startswith("builtins") + ) + assert extension_handler.namespace() not in _deserialization_registry + _serialization_registry[op_type] = extension_handler + _deserialization_registry[extension_handler.namespace()] = extension_handler + + +def _registered_extension_types(): + return tuple(_serialization_registry.keys()) + + +# Registry to store all custom serialization implementations. 
+# The registry maps a operation to its serialization function (a callable), in their own +# namespace to avoid conflicts. +# Serialization: Op type --> custom handler. +# De-serialization: Namespace --> custom handler. +_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {} +_deserialization_registry: dict[str, type[ExtensionHandler]] = {} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/union.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/union.py new file mode 100644 index 0000000000000000000000000000000000000000..c65ad38d337fea7631c122003e263a94cc4870dc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/union.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Hashable +from dataclasses import dataclass, fields +from typing import TypeVar +from typing_extensions import dataclass_transform + + +T = TypeVar("T", bound="_Union") + + +class _UnionTag(str): + __slots__ = ("_cls",) + _cls: Hashable + + @staticmethod + def create(t, cls): + tag = _UnionTag(t) + assert not hasattr(tag, "_cls") + tag._cls = cls + return tag + + def __eq__(self, cmp) -> bool: + assert isinstance(cmp, str) + other = str(cmp) + assert other in _get_field_names(self._cls), ( + f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + ) + return str(self) == other + + def __hash__(self): + return hash(str(self)) + + +@functools.cache +def _get_field_names(cls) -> set[str]: + return {f.name for f in fields(cls)} + + +# If you turn a schema class that inherits from union into a dataclass, please use +# this decorator to configure it. It's safe, faster and allows code sharing. 
+# +# For example, _union_dataclass customizes the __eq__ method to only check the type +# and value property instead of default implementation of dataclass which goes +# through every field in the dataclass. +@dataclass_transform(eq_default=False) +def _union_dataclass(cls: type[T]) -> type[T]: + assert issubclass(cls, _Union), f"{cls} must inheirt from {_Union}." + return dataclass(repr=False, eq=False)(cls) + + +class _Union: + _type: _UnionTag + + @classmethod + def create(cls, **kwargs): + assert len(kwargs) == 1 + obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] + obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) + return obj + + def __post_init__(self): + assert not any( + f.name in ("type", "_type", "create", "value") + for f in fields(self) # type: ignore[arg-type, misc] + ) + + @property + def type(self) -> str: + try: + return self._type + except AttributeError as e: + raise RuntimeError( + f"Please use {type(self).__name__}.create to instantiate the union type." 
+ ) from e + + @property + def value(self): + return getattr(self, self.type) + + def __getattribute__(self, name): + attr = super().__getattribute__(name) + if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] + raise AttributeError(f"Field {name} is not set.") + return attr + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Union): + return False + return self.type == other.type and self.value == other.value + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b629d43ef3b5d9734cf2fc6bf1502026d30c0c30 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py @@ -0,0 +1,190 @@ +import json +import logging +from typing import Any + +from torch._logging import trace_structured +from torch.fx import Graph, Node + + +log: logging.Logger = logging.getLogger(__name__) + + +def create_joint_graph_node_information( + joint_graph: Graph, + recomputable_node_info: dict[str, int], +) -> dict[str, Any]: + joint_graph_node_information: dict[str, Any] = {} + + for i, joint_graph_node in enumerate(joint_graph.nodes): + is_recomputable_candidate: bool = ( + joint_graph_node.name in recomputable_node_info + ) + tensor_meta = joint_graph_node.meta.get("tensor_meta") + shape = getattr(tensor_meta, "shape", []) if tensor_meta else [] + + node_info: dict[str, Any] = { + "index": i, + "name": joint_graph_node.name, + "is_recomputable_candidate": is_recomputable_candidate, + "target": str(joint_graph_node.target), + "shape": str(shape), + "input_arguments": [inp.name for inp in joint_graph_node.all_input_nodes], + "stack_trace": joint_graph_node.meta.get("stack_trace", ""), + } + + if is_recomputable_candidate: + idx: int = recomputable_node_info[joint_graph_node.name] + node_info["recomputable_candidate_info"] = { + "recomputable_node_idx": idx, + } + + joint_graph_node_information[joint_graph_node.name] = node_info + + return joint_graph_node_information + + +def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]: + joint_graph_edges: list[tuple[str, str]] = [ + (inp.name, node.name) + for node in joint_graph.nodes + for 
inp in node.all_input_nodes + ] + return joint_graph_edges + + +def create_activation_checkpointing_logging_structure_payload( + joint_graph: Graph, + joint_graph_node_information: dict[str, Any], + joint_graph_edges: list[tuple[str, str]], + all_recomputable_banned_nodes: list[Node], + expected_runtime: float, + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + memories_banned_nodes: list[int], + normalized_memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> dict[str, Any]: + """ + Creates a structured payload for logging activation checkpointing information. + + Args: + joint_graph: The computational graph representing operations. + joint_graph_node_information: Dictionary containing information about nodes in the joint graph. + joint_graph_edges: List of edges in the joint graph represented as tuples of node names. + all_recomputable_banned_nodes: List of nodes that are banned from recomputation. + expected_runtime: Expected runtime of the computation. + saved_node_idxs: Indices of nodes that are saved (not recomputed). + recomputable_node_idxs: Indices of nodes that can be recomputed. + memories_banned_nodes: Memory usage values (in absolute units) for banned nodes. + normalized_memories_banned_nodes: Normalized memory usage values for banned nodes, + used as input to the knapsack algorithm. + runtimes_banned_nodes: Runtime values for banned nodes, used as input to the + knapsack algorithm. + min_cut_saved_values: List of nodes saved by the min-cut algorithm. + + Returns: + A dictionary containing structured logging information for activation checkpointing. 
+ """ + activation_checkpointing_logging_structure_payload: dict[str, Any] = { + "Joint Graph Size": len(joint_graph.nodes), + "Joint Graph Edges": { + "Total": len(joint_graph_edges), + "Edges": joint_graph_edges, + }, + "Joint Graph Node Information": joint_graph_node_information, + "Recomputable Banned Nodes Order": [ + node.name for node in all_recomputable_banned_nodes + ], + "Expected Runtime": expected_runtime, + "Knapsack Saved Nodes": saved_node_idxs, + "Knapsack Recomputed Nodes": recomputable_node_idxs, + "Absolute Memories": memories_banned_nodes, + "Knapsack Input Memories": normalized_memories_banned_nodes, + "Knapsack Input Runtimes": runtimes_banned_nodes, + "Min Cut Solution Saved Values": [node.name for node in min_cut_saved_values], + } + return activation_checkpointing_logging_structure_payload + + +def create_structured_trace_for_min_cut_info( + joint_graph: Graph, + all_recomputable_banned_nodes: list[Node], + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + expected_runtime: float, + memories_banned_nodes: list[int], + normalized_memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> None: + """ + Creates a structured trace for minimum cut information in the graph. + + Args: + joint_graph: The computational graph representation. + all_recomputable_banned_nodes: List of nodes that can be recomputed. + saved_node_idxs: Indices of nodes that are saved in memory. + recomputable_node_idxs: Indices of nodes that are recomputed. + expected_runtime: Expected runtime for the computation. + memories_banned_nodes: Memory requirements for each banned node in bytes. + normalized_memories_banned_nodes: Normalized memory requirements for each banned node + (typically scaled between 0 and 1 for relative comparison). + runtimes_banned_nodes: Runtime costs associated with each banned node. + min_cut_saved_values: Nodes that are saved as part of the minimum cut solution. 
+ """ + # Create a dictionary to store recomputable node information + recomputable_node_info: dict[str, int] = { + node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes) + } + + # Create joint graph node information + joint_graph_node_information = create_joint_graph_node_information( + joint_graph, recomputable_node_info + ) + + # Update node information with recomputable candidate details + for node_name, node_info in joint_graph_node_information.items(): + if node_info["is_recomputable_candidate"]: + idx = recomputable_node_info[node_name] + node_info["recomputable_candidate_info"]["memory"] = memories_banned_nodes[ + idx + ] + node_info["recomputable_candidate_info"]["runtime"] = runtimes_banned_nodes[ + idx + ] + node_info["recomputable_candidate_info"]["is_saved"] = ( + idx in saved_node_idxs + ) + node_info["recomputable_candidate_info"]["is_recomputed"] = ( + idx in recomputable_node_idxs + ) + + # Create joint graph edges + joint_graph_edges = create_joint_graph_edges(joint_graph) + + # Create activation checkpointing logging structure payload + activation_checkpointing_logging_structure_payload = ( + create_activation_checkpointing_logging_structure_payload( + joint_graph=joint_graph, + joint_graph_node_information=joint_graph_node_information, + joint_graph_edges=joint_graph_edges, + all_recomputable_banned_nodes=all_recomputable_banned_nodes, + expected_runtime=expected_runtime, + saved_node_idxs=saved_node_idxs, + recomputable_node_idxs=recomputable_node_idxs, + memories_banned_nodes=memories_banned_nodes, + normalized_memories_banned_nodes=normalized_memories_banned_nodes, + runtimes_banned_nodes=runtimes_banned_nodes, + min_cut_saved_values=min_cut_saved_values, + ) + ) + + # Create structured trace + trace_structured( + "artifact", + metadata_fn=lambda: {"name": "min_cut_information", "encoding": "json"}, + payload_fn=lambda: json.dumps( + activation_checkpointing_logging_structure_payload + ), + ) diff --git 
class GraphInfoProvider:
    """
    Name-based description of a joint graph for activation-checkpointing analysis.

    Stores the node names in graph order, the ``(producer, user)`` edges between
    them, and a runtime / memory figure for every recomputable candidate. Derived
    graphs (networkx graphs and a simplified ``torch.fx`` graph) are built lazily
    on first property access and cached afterwards.
    """

    __RECOMPUTABLE_NODE_ONLY_GRAPH = "recomputable_node_only_graph"
    __RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT = (
        "recomputable_node_only_graph_with_larger_graph_context"
    )
    __FULL_NX_JOINT_GRAPH = "full_nx_joint_graph"
    __SIMPLIFIED_FX_JOINT_GRAPH = "fx_joint_graph"

    def __init__(
        self,
        graph_nodes_in_order: list[str],
        graph_edges: list[tuple[str, str]],
        all_recomputable_banned_nodes: list[str],
        all_node_runtimes: Optional[dict[str, float]] = None,
        all_node_memories: Optional[dict[str, float]] = None,
        recorded_knapsack_input_memories: Optional[list[float]] = None,
        recorded_knapsack_input_runtimes: Optional[list[float]] = None,
        joint_graph: Optional[Graph] = None,
    ):
        """
        Args:
            graph_nodes_in_order: Names of all graph nodes, in graph order.
            graph_edges: ``(source, destination)`` node-name pairs.
            all_recomputable_banned_nodes: Names of the recomputable candidates;
                a name's position is its knapsack index.
            all_node_runtimes: Runtime per node name. When omitted, it is derived
                from ``recorded_knapsack_input_runtimes``.
            all_node_memories: Memory per node name. When omitted, it is derived
                from ``recorded_knapsack_input_memories``.
            recorded_knapsack_input_memories: Per-candidate memory, by index.
            recorded_knapsack_input_runtimes: Per-candidate runtime, by index.
            joint_graph: Accepted for interface compatibility; not stored.

        Raises:
            ValueError: If neither the per-name mapping nor the recorded list is
                given for runtimes, or for memories.
        """
        self.graph_nodes_in_order = graph_nodes_in_order
        self.graph_edges = graph_edges

        self.all_node_runtimes: dict[str, float] = dict()
        if all_node_runtimes is None:
            if recorded_knapsack_input_runtimes is None:
                raise ValueError(
                    "Either all_node_runtimes or recorded_knapsack_input_runtimes must be provided."
                )
            self.all_node_runtimes = {
                node: recorded_knapsack_input_runtimes[i]
                for i, node in enumerate(all_recomputable_banned_nodes)
            }
        else:
            self.all_node_runtimes.update(all_node_runtimes)

        self.all_node_memories: dict[str, float] = dict()
        if all_node_memories is None:
            if recorded_knapsack_input_memories is None:
                raise ValueError(
                    "Either all_node_memories or recorded_knapsack_input_memories must be provided."
                )
            self.all_node_memories = {
                node: recorded_knapsack_input_memories[i]
                for i, node in enumerate(all_recomputable_banned_nodes)
            }
        else:
            self.all_node_memories.update(all_node_memories)

        self.all_recomputable_banned_nodes = all_recomputable_banned_nodes
        self.all_recomputable_banned_nodes_set = set(all_recomputable_banned_nodes)
        self.recorded_knapsack_input_memories = recorded_knapsack_input_memories
        self.recorded_knapsack_input_runtimes = recorded_knapsack_input_runtimes
        # Cache for the lazily built graphs; None means "not built yet".
        self._lazily_initialized_graphs: dict[str, Any] = {
            self.__RECOMPUTABLE_NODE_ONLY_GRAPH: None,
            self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT: None,
            self.__FULL_NX_JOINT_GRAPH: None,
            self.__SIMPLIFIED_FX_JOINT_GRAPH: None,
        }

    @classmethod
    def inialize_from_graph(
        cls,
        joint_graph: Graph,
        all_recomputable_banned_nodes: list[Node],
        recorded_knapsack_input_memories: list[float],
        recorded_knapsack_input_runtimes: list[float],
    ) -> "GraphInfoProvider":
        """
        Build a provider from a ``torch.fx`` joint graph.

        Node names are taken in graph iteration order and one edge is recorded
        per ``(node, user)`` pair.
        """
        graph_nodes_in_order = [node.name for node in joint_graph.nodes]
        graph_edges = [
            (node.name, user.name) for node in joint_graph.nodes for user in node.users
        ]
        all_recomputable_banned_node_names = [
            node.name for node in all_recomputable_banned_nodes
        ]
        return cls(
            graph_nodes_in_order=graph_nodes_in_order,
            graph_edges=graph_edges,
            all_recomputable_banned_nodes=all_recomputable_banned_node_names,
            recorded_knapsack_input_memories=recorded_knapsack_input_memories,
            recorded_knapsack_input_runtimes=recorded_knapsack_input_runtimes,
            joint_graph=joint_graph,
        )

    @property
    def recomputable_node_only_graph(self) -> "nx.DiGraph":
        """Graph of candidates with only the direct candidate-to-candidate edges."""
        key = self.__RECOMPUTABLE_NODE_ONLY_GRAPH
        if self._lazily_initialized_graphs[key] is None:
            self._lazily_initialized_graphs[key] = (
                self._create_recomputable_node_only_graph()
            )
        return self._lazily_initialized_graphs[key]

    @property
    def recomputable_node_only_graph_with_larger_graph_context(self) -> "nx.DiGraph":
        """Graph of candidates whose edges also reflect paths through other nodes."""
        key = self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT
        if self._lazily_initialized_graphs[key] is None:
            self._lazily_initialized_graphs[key] = (
                self._create_recomputable_node_only_graph_with_larger_graph_context()
            )
        return self._lazily_initialized_graphs[key]

    @property
    def full_joint_nx_graph(self) -> "nx.DiGraph":
        """networkx graph of every node and edge except those named ``"output"``."""
        key = self.__FULL_NX_JOINT_GRAPH
        if self._lazily_initialized_graphs[key] is None:
            self._lazily_initialized_graphs[key] = self._create_full_joint_graph()
        return self._lazily_initialized_graphs[key]

    @property
    def simplified_fx_joint_graph(self) -> Graph:
        """``torch.fx`` graph with the recorded node names and dependencies."""
        key = self.__SIMPLIFIED_FX_JOINT_GRAPH
        if self._lazily_initialized_graphs[key] is None:
            self._lazily_initialized_graphs[key] = self._recreate_psuedo_joint_graph()
        return self._lazily_initialized_graphs[key]

    def _ordered_unique_recomputable_nodes(self) -> list[str]:
        """Candidate names with duplicates removed, in their original order."""
        return list(dict.fromkeys(self.all_recomputable_banned_nodes))

    def get_non_ac_peak_memory(self) -> float:
        """Sum of the memory of every distinct candidate node."""
        # Iterate in list order rather than set order so the floating-point
        # summation order does not depend on string hashing.
        return sum(
            self.all_node_memories[node_name]
            for node_name in self._ordered_unique_recomputable_nodes()
        )

    def get_theoretical_max_runtime(self) -> float:
        """Sum of the runtime of every distinct candidate node."""
        return sum(
            self.all_node_runtimes[node_name]
            for node_name in self._ordered_unique_recomputable_nodes()
        )

    def get_knapsack_memory_input(self) -> list[float]:
        """Recorded per-candidate memories if present, else looked up by name."""
        if self.recorded_knapsack_input_memories:
            return self.recorded_knapsack_input_memories
        return [
            self.all_node_memories[node_name]
            for node_name in self.all_recomputable_banned_nodes
        ]

    def get_knapsack_runtime_input(self) -> list[float]:
        """Recorded per-candidate runtimes if present, else looked up by name."""
        if self.recorded_knapsack_input_runtimes:
            return self.recorded_knapsack_input_runtimes
        return [
            self.all_node_runtimes[node_name]
            for node_name in self.all_recomputable_banned_nodes
        ]

    def _create_recomputable_node_only_graph(self) -> "nx.DiGraph":
        graph = nx.DiGraph()
        for recomputable_node in self.all_recomputable_banned_nodes:
            graph.add_node(recomputable_node)

        candidates = self.all_recomputable_banned_nodes_set
        for a, b in self.graph_edges:
            if a in candidates and b in candidates:
                graph.add_edge(a, b)
        return graph

    def _create_recomputable_node_only_graph_with_larger_graph_context(
        self,
    ) -> "nx.DiGraph":
        candidates = self.all_recomputable_banned_nodes_set
        # Nodes and edges are inserted in candidate-list order so the resulting
        # graph is built the same way on every run.
        ordered_candidates = self._ordered_unique_recomputable_nodes()
        full_graph = self.full_joint_nx_graph

        # For each candidate, the other candidates found by a BFS from it.
        reachable_nodes: dict[str, set[str]] = {}
        for node in ordered_candidates:
            predecessors = dict(nx.bfs_predecessors(full_graph, node))
            reachable_nodes[node] = set(predecessors.keys()).intersection(candidates)

        candidate_graph = nx.DiGraph()
        candidate_graph.add_nodes_from(ordered_candidates)
        for node1 in ordered_candidates:
            reachable = reachable_nodes[node1]
            for node2 in ordered_candidates:
                if node2 not in reachable:
                    continue
                # Skip node1 -> node2 when node2 is also reachable through
                # another candidate that node1 reaches.
                overlapping_path = any(
                    intermediate_node != node2
                    and node2 in reachable_nodes[intermediate_node]
                    for intermediate_node in reachable
                )
                if not overlapping_path:
                    candidate_graph.add_edge(node1, node2)
        return candidate_graph

    def _create_full_joint_graph(self) -> "nx.DiGraph":
        graph = nx.DiGraph()
        for node in self.graph_nodes_in_order:
            if node == "output":
                continue
            graph.add_node(node)

        for a, b in self.graph_edges:
            if a == "output" or b == "output":
                continue
            graph.add_edge(a, b)
        return graph

    def _recreate_psuedo_joint_graph(self) -> Graph:
        # Inputs of each node, collected from the recorded edges.
        node_dependencies: dict[str, list[str]] = {
            node: [] for node in self.graph_nodes_in_order
        }
        for a, b in self.graph_edges:
            if a not in node_dependencies or b not in node_dependencies:
                raise ValueError(f"Edge ({a}, {b}) references a non-existent node.")
            node_dependencies[b].append(a)

        joint_graph = Graph()
        nodes: dict[str, Node] = {}
        for node_name in self.graph_nodes_in_order:
            input_nodes = [nodes[dep] for dep in node_dependencies[node_name]]
            if input_nodes:
                # The callable is a stand-in; only the dependency structure matters.
                node = joint_graph.call_function(lambda *x: x, tuple(input_nodes))
                node.name = node_name
            else:
                node = joint_graph.placeholder(node_name)
            nodes[node_name] = node
        return joint_graph

    def _visualize_recomputable_candidate_graph_with_larger_context(
        self,
        layout_k: float = 0.5,
        layout_iterations: int = 30,
    ) -> None:
        """
        Plot the candidate graph with nodes colored by their knapsack memory input.

        Args:
            layout_k: ``k`` passed to ``networkx.spring_layout``.
            layout_iterations: ``iterations`` passed to ``networkx.spring_layout``.
        """
        from matplotlib import cm, colors as mcolors, pyplot as plt

        graph = self.recomputable_node_only_graph_with_larger_graph_context
        pos = nx.spring_layout(graph, k=layout_k, iterations=layout_iterations)
        plt.figure(figsize=(20, 15))

        # Label each node with its "index" attribute when set, else its name.
        labels = {node: graph.nodes[node].get("index", node) for node in graph.nodes}

        knapsack_memories = self.get_knapsack_memory_input()
        norm = mcolors.Normalize(
            vmin=min(knapsack_memories),
            vmax=max(knapsack_memories),
        )
        cmap = cm.viridis  # type: ignore[attr-defined]

        # The graph builders attach no "memory" attribute to nodes, so color from
        # the same per-candidate values the normalization range was taken from.
        memory_by_node = dict(zip(self.all_recomputable_banned_nodes, knapsack_memories))
        node_colors = [
            cmap(norm(float(graph.nodes[node].get("memory", memory_by_node[node]))))
            for node in graph.nodes
        ]

        nx.draw_networkx_nodes(
            graph,
            pos,
            node_color=node_colors,
            node_size=300,
            label="Parsed Nodes",
        )
        nx.draw_networkx_edges(
            graph,
            pos,
            arrows=True,
            arrowsize=10,
        )
        nx.draw_networkx_labels(
            graph,
            pos,
            labels=labels,
            font_size=8,
            font_weight="bold",
        )

        plt.title("Memory Colour Coded Dependency Graph for Recomputable Nodes")
        plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), label="Memory")
        plt.show()
def greedy_knapsack(
    memory: list[float], runtimes: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
    """
    Greedy 0-1 knapsack: take items by descending runtime-per-memory ratio.

    Args:
        memory: Memory cost of each item.
        runtimes: Runtime value of each item.
        max_memory: Memory budget.

    Returns:
        ``(total_runtime, items_to_save, items_to_allow_recomputing)``; both index
        lists are in the order the items were visited (descending ratio).
    """

    def value_density(i: int) -> float:
        # An item that occupies no memory is free to keep, so rank it first
        # instead of dividing by zero.
        if memory[i] == 0:
            return float("inf")
        return runtimes[i] / memory[i]

    items = sorted(range(len(runtimes)), key=value_density, reverse=True)

    total_memory = 0.0
    total_runtime = 0.0
    items_to_save = []
    items_to_allow_recomputing = []

    for i in items:
        if total_memory + memory[i] <= max_memory:
            total_memory += memory[i]
            total_runtime += runtimes[i]
            items_to_save.append(i)
        else:
            items_to_allow_recomputing.append(i)
    return total_runtime, items_to_save, items_to_allow_recomputing


def ilp_knapsack(
    memory: list[float], runtimes: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
    """
    Solve the 0-1 knapsack exactly with ``scipy.optimize.milp``.

    Returns:
        ``(total_runtime, items_to_save, items_to_allow_recomputing)``.

    Raises:
        RuntimeError: If scipy is not installed or the solver reports failure.
    """
    # Nothing to decide for an empty problem.
    if len(memory) == 0:
        return 0.0, [], []

    import numpy as np

    try:
        from scipy.optimize import Bounds, LinearConstraint, milp
    except ImportError:
        raise RuntimeError(
            "To use the ILP for memory budget checkpointing you need to install scipy"
        ) from None

    np_memory = np.array(memory)
    np_runtimes = np.array(runtimes)
    # milp minimizes, so negate the runtimes to maximize saved runtime.
    c = -np_runtimes  # type: ignore[operator]

    memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory))
    constraints = [memory_constraint]

    integrality = np.ones_like(c)
    res = milp(
        c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1)
    )
    if not res.success:
        raise RuntimeError("Somehow scipy solving failed")

    items_to_save = []
    items_to_allow_recomputing = []
    for idx, i in enumerate(res.x):
        # The solution is returned as floats; round before comparing so a value
        # such as 0.9999999 still counts as "selected".
        if round(i) == 1:
            items_to_save.append(idx)
        else:
            items_to_allow_recomputing.append(idx)
    return -res.fun, items_to_save, items_to_allow_recomputing


def dp_knapsack(
    memory: list[float], runtime: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
    """
    Quantized pseudo-polynomial DP for the 0-1 knapsack (full table).

    Memory values are multiplied by 10000 and rounded to integers.

    Returns:
        ``(max_runtime, saved_items, recomputable_items)``; ``saved_items`` is
        ascending and ``recomputable_items`` is descending.
    """
    # Scaling factor to convert floating point weights to integers
    S = 10000

    quantized_memory = [round(m * S) for m in memory]
    runtimes = torch.tensor(runtime, dtype=torch.float32, device="cpu")
    quantized_max_memory = round(max_memory * S)

    n = len(memory)

    # dp[i, j]: best runtime using the first i items within capacity j.
    # TODO(chilli): I think if needed, this memory can be optimized with sliding
    # window trick + Hirschberg trick:
    # https://codeforces.com/blog/entry/47247?#comment-316200
    dp = torch.zeros(
        (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu"
    )

    for i in range(1, n + 1):
        current_memory = quantized_memory[i - 1]
        current_runtime = runtimes[i - 1]

        if current_memory == 0:
            # A weightless item improves every capacity.
            dp[i, :] = dp[i - 1, :] + current_runtime
            continue

        # Capacities below the item's weight keep the previous row.
        dp[i, :] = dp[i - 1, :]
        dp[i, current_memory:] = torch.maximum(
            dp[i - 1, current_memory:],
            dp[i - 1, :-current_memory] + current_runtime,
        )

    # Backtrack to find the items included in the knapsack
    saved_items = []
    recomputable_items = []
    j: int = quantized_max_memory
    for i in range(n, 0, -1):
        if dp[i][j] != dp[i - 1][j]:
            saved_items.append(i - 1)  # Include this item (indexing from 0)
            j -= quantized_memory[i - 1]
        else:
            recomputable_items.append(i - 1)

    saved_items.reverse()  # To get items in the order they were added

    # The maximum runtime that can be achieved within the max_memory constraint
    max_runtime = dp[n][quantized_max_memory].item()

    return max_runtime, saved_items, recomputable_items


def dp_knapsack_sliding_hirschberg(
    memory: list[float], runtime: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
    """
    Quantized 0-1 knapsack DP using a single DP row and divide-and-conquer.

    Each segment of items is split in half, a one-row DP profile is computed for
    each half, and the capacity is divided between the halves at the split that
    maximizes their combined value. Segments of one item are decided directly.

    Returns:
        ``(max_runtime, saved_items, recomputable_items)``; ``saved_items`` is
        ascending and ``recomputable_items`` is descending.
    """
    # Scaling factor to convert floating point weights to integers
    S = 10000

    # q_ prefix stands for quantized
    q_memory = [int(round(m * S)) for m in memory]
    runtimes = [float(v) for v in runtime]

    q_max_memory = int(round(max_memory * S))

    q_memory_length = len(q_memory)
    if q_memory_length == 0:
        return 0.0, [], []

    item_indices = list(range(q_memory_length))
    dp_profile_size = q_max_memory + 1

    # Buffers are allocated once for the full budget and sliced per segment.
    dp_profile = torch.zeros(dp_profile_size, dtype=torch.float32, device="cpu")
    candidate_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu")
    left_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu")
    right_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu")

    def fill_profile(items: list[int], profile_size: int) -> torch.Tensor:
        """Return a view of dp_profile holding the best value per capacity for items."""
        dp_view = dp_profile[:profile_size]
        candidate_view = candidate_profile[:profile_size]
        dp_view.zero_()
        for index in items:
            memory_item = q_memory[index]
            runtime_item = runtimes[index]

            if memory_item == 0:
                # Weight is 0, so add it to all capacities; a "free lunch", essentially
                dp_view.add_(runtime_item)
                continue

            # If item is too heavy for every capacity in the profile, skip it
            if memory_item >= profile_size:
                continue

            # Value of taking the item, then keep the better of with / without it
            dp_view_candidate = candidate_view[: profile_size - memory_item]
            torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate)
            torch.maximum(
                dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:]
            )
        return dp_view

    saved_items: list[int] = []
    recomputable_items: list[int] = []

    # Explicit stack to optimize memory and avoid recursion
    # Stack stores segments as (start index, end index, capacity for segment)
    stack: list[tuple[int, int, int]] = [(0, q_memory_length, q_max_memory)]

    while stack:
        start, end, capacity = stack.pop()
        length = end - start
        if length == 0:
            continue

        # Leaf: a single item is saved if it fits and is worth anything
        if length == 1:
            index = item_indices[start]
            if q_memory[index] <= capacity and runtimes[index] > 0.0:
                saved_items.append(index)
            else:
                recomputable_items.append(index)
            continue

        # Split the segment into two halves
        middle = start + (length // 2)
        left_start, left_end = middle, end
        right_start, right_end = start, middle

        # Profiles cover capacities 0..capacity inclusive
        profile_size = capacity + 1
        left_dp_local = left_profile[:profile_size]
        right_dp_local = right_profile[:profile_size]

        left_dp_local.copy_(
            fill_profile(item_indices[left_start:left_end], profile_size)
        )
        # Reversed so that index k pairs left capacity k with right capacity
        # (capacity - k).
        right_dp_local.copy_(
            fill_profile(item_indices[right_start:right_end], profile_size).flip(-1)
        )

        # In-place item-wise sum; its argmax is the best way to divide capacity
        left_dp_local.add_(right_dp_local)
        best_split = int(torch.argmax(left_dp_local).item())

        # best_split lies in [0, capacity], so the two shares sum to exactly
        # `capacity` and neither can be negative or exceed the budget.
        left_capacity = best_split
        right_capacity = capacity - best_split

        # Push right then left, so left is processed next
        stack.append((right_start, right_end, right_capacity))
        stack.append((left_start, left_end, left_capacity))

    saved_items = sorted(saved_items)
    recomputable_items = sorted(recomputable_items)

    max_runtime = sum(runtime[i] for i in saved_items)
    recomputable_items.reverse()
    return max_runtime, saved_items, recomputable_items
class KnapsackEvaluator:
    """
    This class evaluates the theoretical runtime and peak memory usage of a given checkpointing strategy.
    It takes in a graph and a list of nodes that are saved and recomputed, and then simulates the
    backward pass to calculate the peak memory usage.
    """

    def __init__(
        self,
        graph_info_provider: "GraphInfoProvider",
    ) -> None:
        self._graph_info_provider = graph_info_provider

    def _get_backward_memory_from_topologically_sorted_graph(
        self,
        node_graph: "nx.DiGraph",
        node_memories: dict[str, float],
        saved_nodes_set: set[str],
        peak_memory_after_forward_pass: float,
    ) -> list[tuple[float, str]]:
        """
        Simulates the backward pass and keeps track of the memory usage.

        High Level Steps:
            1. Set Initial Peak/Current Memory
                Allows you to set the peak memory after the forward pass, but typically this is
                the sum of the estimated memory of the saved nodes.
            2. Perform a reverse topological sort of the node_graph.
            3. Iterate through the sorted graph nodes.
                If the node is saved (or was already recomputed) then just drop its memory
                from current memory.
                If the node is not saved then add its memory to current memory and then traverse
                its predecessors to simulate the recomputation chain, adding each one's memory.
                Finally drop the node's own memory again.

        Args:
            node_graph (nx.DiGraph): A directed graph representing the recomputable forward nodes.
            node_memories (Dict[str, float]): Memory of each node, keyed by node name.
            saved_nodes_set (Set[str]): A set of node names that are saved.
            peak_memory_after_forward_pass (float): The peak memory usage after the forward pass.

        Returns:
            The running memory trace as ``(memory, description)`` tuples, starting with the
            initial entry. The caller takes the maximum to obtain the peak.
        """
        current_memory = [
            (peak_memory_after_forward_pass, "Initial Peak/Current Memory")
        ]
        already_computed = set()
        sorted_nodes = list(reversed(list(nx.topological_sort(node_graph))))

        for node in sorted_nodes:
            if node in saved_nodes_set or node in already_computed:
                current_memory.append(
                    (
                        current_memory[-1][0] - node_memories[node],
                        f"Dropping Node(already saved): {node}",
                    )
                )
                continue

            already_computed.add(node)
            current_memory.append(
                (
                    current_memory[-1][0] + node_memories[node],
                    f"Recomputing Node: {node}",
                )
            )
            # Queue the direct dependencies required for recomputation.
            # NOTE(review): direct predecessors are filtered only by
            # already_computed, while transitive ones below also skip saved
            # nodes - confirm whether that asymmetry is intended.
            predecessor_queue = deque(
                [
                    dependency
                    for dependency, _ in node_graph.in_edges(node)
                    if dependency not in already_computed
                ]
            )
            while predecessor_queue:
                dep = predecessor_queue.popleft()
                already_computed.add(dep)
                current_memory.append(
                    (
                        current_memory[-1][0] + node_memories[dep],
                        f"Recomputing Predecessor of {node}: {dep}",
                    )
                )
                # Add predecessors of the predecessor to the queue if they haven't been recomputed yet
                for dependency_of_dependency, _ in node_graph.in_edges(dep):
                    if (
                        dependency_of_dependency in already_computed
                        or dependency_of_dependency in saved_nodes_set
                        or dependency_of_dependency in predecessor_queue
                    ):
                        continue
                    predecessor_queue.append(dependency_of_dependency)
            current_memory.append(
                (current_memory[-1][0] - node_memories[node], f"Dropping Node: {node}")
            )
        return current_memory

    def _validate_all_indexes_accounted_for_in_provided_output(
        self, saved_nodes_idxs: list[int], recomputable_node_idxs: list[int]
    ) -> None:
        """
        Validate that all indexes are accounted for in the provided output.
        This function checks that the union of saved nodes and recomputable nodes
        covers all candidate nodes without any overlaps.
        """
        recomputable_node_idxs_set = set(recomputable_node_idxs)
        saved_nodes_idxs_set = set(saved_nodes_idxs)
        all_candidate_nodes_idxs = set(
            range(len(self._graph_info_provider.all_recomputable_banned_nodes))
        )
        # Check that there are no overlaps between saved nodes and recomputable nodes
        assert (
            len(recomputable_node_idxs_set.intersection(saved_nodes_idxs_set)) == 0
        ), "Saved nodes and recomputable nodes cannot have any overlaps"
        # Check that all candidate nodes are accounted for
        assert (
            recomputable_node_idxs_set.union(saved_nodes_idxs_set)
            == all_candidate_nodes_idxs
        ), "All candidate nodes must be accounted for in the provided output"

    def evaluate_knapsack_output(
        self,
        saved_nodes_idxs: list[int],
        recomputable_node_idxs: list[int],
        account_for_backward_pass: bool = False,
    ) -> dict[str, float]:
        """
        Evaluate the theoretical runtime and peak memory usage of a given checkpointing strategy.

        Args:
            saved_nodes_idxs (List[int]): The indices of nodes that are saved.
            recomputable_node_idxs (List[int]): The indices of nodes that need to be recomputed.
            account_for_backward_pass (bool): When True, the peak memory is the maximum of the
                simulated backward-pass memory trace; otherwise it is the summed memory of the
                saved nodes.

        Returns:
            A dict with ``peak_memory``, ``recomputation_runtime``, ``non_ac_peak_memory``,
            ``theoretical_max_runtime`` and the two ratios
            ``percentage_of_theoretical_peak_memory`` / ``percentage_of_theoretical_peak_runtime``.
            A ratio whose denominator is zero is reported as 0.0.
        """
        self._validate_all_indexes_accounted_for_in_provided_output(
            saved_nodes_idxs, recomputable_node_idxs
        )
        provider = self._graph_info_provider
        candidate_names = provider.all_recomputable_banned_nodes

        recomputation_runtime = sum(
            provider.all_node_runtimes[candidate_names[idx]]
            for idx in recomputable_node_idxs
        )
        saved_memory = sum(
            provider.all_node_memories[candidate_names[idx]]
            for idx in saved_nodes_idxs
        )
        if account_for_backward_pass:
            memory_list = self._get_backward_memory_from_topologically_sorted_graph(
                node_graph=provider.recomputable_node_only_graph_with_larger_graph_context,
                saved_nodes_set={candidate_names[i] for i in saved_nodes_idxs},
                node_memories=provider.all_node_memories,
                peak_memory_after_forward_pass=saved_memory,
            )
            peak_memory = max(memory_list, key=operator.itemgetter(0))[0]
        else:
            peak_memory = saved_memory

        non_ac_peak_memory = provider.get_non_ac_peak_memory()
        theoretical_max_runtime = provider.get_theoretical_max_runtime()
        # Guard the ratios: with no candidates (or all-zero estimates) the totals
        # are zero and a plain division would raise.
        memory_ratio = peak_memory / non_ac_peak_memory if non_ac_peak_memory else 0.0
        runtime_ratio = (
            recomputation_runtime / theoretical_max_runtime
            if theoretical_max_runtime
            else 0.0
        )
        return {
            "peak_memory": peak_memory,
            "recomputation_runtime": recomputation_runtime,
            "non_ac_peak_memory": non_ac_peak_memory,
            "theoretical_max_runtime": theoretical_max_runtime,
            "percentage_of_theoretical_peak_memory": memory_ratio,
            "percentage_of_theoretical_peak_runtime": runtime_ratio,
        }

    def evaluate_distribution_of_results_for_knapsack_algo(
        self,
        knapsack_algo: Callable[
            [list[float], list[float], float], tuple[float, list[int], list[int]]
        ],
        memory_budget_values: list[float],
    ) -> list[dict[str, float]]:
        """
        Evaluates the distribution of results for a given knapsack algorithm.

        Args:
            knapsack_algo (Callable): The knapsack algorithm to use for evaluation.
            memory_budget_values (List[float]): A list of memory budgets to evaluate.

        Returns:
            One ``evaluate_knapsack_output`` result per budget, each with an added
            ``memory_budget`` entry, in the order the budgets were given.
        """
        results = []
        for memory_budget in memory_budget_values:
            _, saved_nodes, recomputed_nodes = knapsack_algo(
                self._graph_info_provider.get_knapsack_memory_input(),
                self._graph_info_provider.get_knapsack_runtime_input(),
                memory_budget,
            )
            result = self.evaluate_knapsack_output(
                saved_nodes_idxs=saved_nodes,
                recomputable_node_idxs=recomputed_nodes,
            )
            result["memory_budget"] = memory_budget
            results.append(result)
        return results

    def get_knee_point_memory_budget(
        self,
        knapsack_algo: Callable[
            [list[float], list[float], float], tuple[float, list[int], list[int]]
        ],
        max_mem_budget: float = 0.1,
        min_mem_budget: float = 0.001,
        iterations: int = 100,
    ) -> float:
        """
        Finds the memory budget at the knee point in the Pareto frontier.

        Budgets are sampled evenly between ``min_mem_budget`` and ``max_mem_budget``.
        The knee point is the sample whose normalized (runtime, memory) pair is
        closest to the origin.

        Args:
            knapsack_algo (callable): Knapsack algorithm to use for evaluation.
            max_mem_budget (float, optional): Maximum memory budget. Defaults to 0.1.
            min_mem_budget (float, optional): Minimum memory budget. Defaults to 0.001.
            iterations (int, optional): Number of memory budgets to evaluate. Defaults to 100.

        Returns:
            float: Memory budget at the knee point, or ``max_mem_budget`` when the
            sampled runtimes or memories do not vary.

        Raises:
            ValueError: If ``iterations`` is less than 1.
        """
        if iterations < 1:
            raise ValueError("iterations must be at least 1")
        if iterations == 1:
            # A single sample has no step size; evaluate the lower bound only.
            memory_budget_values = [min_mem_budget]
        else:
            step = (max_mem_budget - min_mem_budget) / (iterations - 1)
            memory_budget_values = [
                min_mem_budget + i * (max_mem_budget - min_mem_budget) / (iterations - 1)
                if i
                else min_mem_budget + 0 * step
                for i in range(iterations)
            ]

        results = self.evaluate_distribution_of_results_for_knapsack_algo(
            knapsack_algo=knapsack_algo,
            memory_budget_values=memory_budget_values,
        )
        runtime_values = [
            result["percentage_of_theoretical_peak_runtime"] for result in results
        ]
        memory_values = [
            result["percentage_of_theoretical_peak_memory"] for result in results
        ]
        runtime_min = min(runtime_values)
        memory_min = min(memory_values)
        runtime_range = max(runtime_values) - runtime_min
        memory_range = max(memory_values) - memory_min
        if runtime_range == 0 or memory_range == 0:
            return max_mem_budget

        # Normalize values
        runtime_norm = [
            (value - runtime_min) / runtime_range for value in runtime_values
        ]
        memory_norm = [(value - memory_min) / memory_range for value in memory_values]
        # Calculate Euclidean distance
        distances = [
            (r**2 + m**2) ** 0.5 for r, m in zip(runtime_norm, memory_norm)
        ]
        # Find the knee point(shortest distance from the origin)
        knee_index = distances.index(min(distances))
        return results[knee_index]["memory_budget"]
def is_impure_node_for_dce(node):
    """Impurity predicate handed to ``Graph.eliminate_dead_code`` by the remat pass.

    Returns ``False`` for any node that ``is_not_collective`` rejects, so an unused
    collective can be removed by DCE. Every other node defers to ``node.is_impure``.
    """
    # Check for special collectives that should be treated as pure
    if not is_not_collective(node):
        # It's a collective (wait_tensor, all_gather_into_tensor, etc.)
        # Treat as pure - can be eliminated if unused
        return False

    # For everything else, fall back to the DEFAULT logic
    # This is what eliminate_dead_code() calls when is_impure_node=None
    impure_random = True
    if torch._guards.TracingContext.try_get():
        impure_random = torch._inductor.config.fallback_random
    return node.is_impure(impure_random)


def _is_backward_node(node: fx.Node) -> bool:
    """Return True when ``node.meta["custom"]["remat_pass_tag"]`` equals ``"is_backward"``."""
    return node.meta.get("custom", {}).get("remat_pass_tag", None) == "is_backward"


def _gather_checkpointed_deps(
    node, visited: set, recomputed_nodes: dict, needs_recompute
) -> None:
    """Add ``node`` and its transitive must-recompute inputs to ``visited``.

    The walk stops at nodes that are already in ``visited`` or already present in
    ``recomputed_nodes``; only inputs for which ``needs_recompute(inp)`` is true are
    followed. An explicit stack is used instead of recursion so that a long chain of
    checkpointed nodes cannot exceed the interpreter's recursion limit.

    Args:
        node: Starting node; anything exposing ``all_input_nodes``.
        visited: Set mutated in place with every node reached.
        recomputed_nodes: Mapping whose keys have already been duplicated.
        needs_recompute: Predicate selecting which inputs to follow.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current in visited or current in recomputed_nodes:
            continue
        visited.add(current)
        pending.extend(inp for inp in current.all_input_nodes if needs_recompute(inp))


def remat_using_tags_for_fwd_loss_bwd_graph(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Duplicate checkpointed nodes for backward use. DCE removes unused forward versions. We assume that
    you already annotated your backward region with fx.traceback.annotate({"remat_pass_tag": "is_backward"})
    which helps us identify the backward region.

    Args:
        gm: Graph module holding forward, loss and backward in a single graph.

    Returns:
        ``gm`` itself when it has no recomputable ops or no annotated backward node
        (the latter also emits a warning); otherwise a new ``GraphModule`` in which
        every backward use of a must-recompute node reads a ``*_recomputed`` copy.

    Raises:
        RuntimeError: if ``has_recomputable_rng_ops(gm)`` is true.
    """
    if not has_recomputable_ops(gm):
        return gm

    # Find backward boundary and build ordering
    bwd_start: int | None = None
    order = {}
    for idx, node in enumerate(gm.graph.nodes):
        order[node] = idx
        if _is_backward_node(node) and bwd_start is None:
            bwd_start = idx

    if bwd_start is None:
        warnings.warn(
            "remat_using_tags_for_fwd_loss_bwd_graph: Graph has recomputable ops but no backward region. "
            "This may indicate a forward-only graph (e.g., from nested compilation) or missing backward annotations. "
            "Returning graph unchanged.",
            stacklevel=2,
        )
        return gm

    if has_recomputable_rng_ops(gm):
        raise RuntimeError(
            "Activation checkpoint rematerializing in `forward-loss-backward` graph does not support RNG ops "
            "in checkpointed regions. Please move RNG operations outside "
            "of checkpoint regions, or use joint graph mode (where partitioner handles RNG)."
        )

    # Use partitioner pass to normalize AC node tags.
    gm = cleanup_recompute_tags(gm, is_default_partition=True)

    if not config.unsafe_allow_optimization_of_collectives:
        force_save_collectives(gm)

    force_save_bw_mutation_src(gm)

    new_graph = fx.Graph()
    env: dict[fx.Node, fx.Node] = {}
    recomputed_nodes: dict[fx.Node, fx.Node] = {}

    # Snapshot the node list once (after the passes above) and slice it for both regions.
    nodes = list(gm.graph.nodes)

    # Insert forward nodes
    for node in nodes[:bwd_start]:
        env[node] = new_graph.node_copy(node, lambda x: env[x])

    def remat_input(x):
        # fx.Node can have args that are primitive types (e.g. int, float, bool)
        if not isinstance(x, fx.Node):
            return x
        # Prefer the recomputed duplicate when one exists, else the forward copy.
        return recomputed_nodes.get(x, env[x])

    # Insert backward nodes
    for node in nodes[bwd_start:]:
        # Gather all checkpointed deps needed by this node
        deps = set()
        for inp in node.all_input_nodes:
            if must_recompute(inp):
                _gather_checkpointed_deps(inp, deps, recomputed_nodes, must_recompute)

        # Insert deps in forward order (guaranteed disjoint from already-inserted)
        # This is not as inefficient as it looks, because we only add fresh dependencies
        # when they are not yet processed as recomputed nodes.
        for dep in sorted(deps, key=lambda n: order[n]):
            assert dep not in recomputed_nodes, "We shouldn't have recomputed it before"
            dup = new_graph.node_copy(dep, remat_input)
            dup.name = dep.name + "_recomputed"
            recomputed_nodes[dep] = dup

        env[node] = new_graph.node_copy(node, remat_input)

    new_gm = torch.fx.GraphModule(gm, new_graph)

    # DCE with custom is_impure_node (like default_partition)
    # Treats certain collectives as pure while delegating to default impurity logic
    new_gm.graph.eliminate_dead_code(is_impure_node=is_impure_node_for_dce)

    # raise_getitems pass for better memory (like default_partition)
    new_gm = raise_getitems(new_gm)

    new_gm.recompile()

    return new_gm
log = logging.getLogger(__name__)


# Type of compiled artifact an InductorOutput produces; always an OutputCode subclass.
TOut = TypeVar("TOut", bound=OutputCode)


class InductorOutput(ABC, Generic[TOut]):
    """
    Class representing a single inductor output

    Subclasses implement three steps: ``pre_save`` (prepare for serialization),
    ``load`` (produce the ``TOut`` artifact for a set of example inputs) and
    ``post_compile`` (finish the loaded artifact using ``fx_config``).
    """

    @abstractmethod
    def pre_save(self) -> None: ...

    @abstractmethod
    def load(self, example_inputs) -> TOut: ...

    @abstractmethod
    def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ...


TOutputCode = TypeVar("TOutputCode", bound=OutputCode)


@dataclass
class BundledOutputCodeLoadable(InductorOutput[TOutputCode], Generic[TOutputCode]):
    """
    A generic wrapper for OutputCode objects that are bundled directly in the cache
    (rather than looked up via FxGraphCache).

    This works for any OutputCode subclass (CompiledFxGraph, RegionalOutputCode, etc.)
    """

    # The bundled artifact itself; replaced by a serialization-ready copy in pre_save().
    result: TOutputCode

    def pre_save(self) -> None:
        """Swap ``self.result`` for a shallow copy prepared for serialization."""
        # Work on a copy so prepare_for_serialization() is not called on the original object.
        disk_result = copy(self.result)
        disk_result.prepare_for_serialization()
        self.result = disk_result
        return

    def load(self, example_inputs) -> TOutputCode:
        """Remember ``example_inputs`` for post_compile() and return the bundled result."""
        self.example_inputs = example_inputs
        return self.result

    def post_compile(
        self, result: TOutputCode, fx_config: _CompileFxKwargs
    ) -> TOutputCode:
        """Run post-compile on ``result``; must be called after load().

        Raises:
            RuntimeError: if ``FxGraphCache.cache_hit_post_compile`` returns no graph.
        """
        constants = CompiledFxGraphConstants()

        # Special handling for CompiledFxGraph - needs FxGraphCache.cache_hit_post_compile
        if isinstance(result, CompiledFxGraph):
            graph, cache_info = FxGraphCache.cache_hit_post_compile(
                result, {}, constants
            )
            if graph is None:
                raise RuntimeError("Failed to reload cache entry from disk")
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "fx_graph_bundled_cache_hit",  # always a hit
                    "encoding": "json",
                },
                payload_fn=lambda: json.dumps(cache_info),
            )
            result = graph  # type: ignore[assignment]

        # Run normal post compile
        result.post_compile(self.example_inputs, constants, fx_config)
        return result


# Backwards compatibility alias
CompiledFxGraphLoadable: type[BundledOutputCodeLoadable[CompiledFxGraph]] = (
    BundledOutputCodeLoadable[CompiledFxGraph]
)


@dataclass
class FxGraphCacheLoadable(InductorOutput[CompiledFxGraph]):
    """
    InductorOutput that stores only an FxGraphCache key (plus the guard expression
    it was saved with) and looks the compiled graph up in FxGraphCache on load.
    """

    # (cache_key, debug_lines) as passed to FxGraphCache.load_with_key
    fx_graph_cache_info: tuple[str, list[str]]
    # Guard expression saved with the entry; load() only accepts an exact match.
    fx_graph_guard_expr: Optional[str]

    def pre_save(self) -> None:
        # Nothing to prepare: only the key and guard expression are stored.
        return

    def _is_backward(self) -> bool:
        # Overridden by CompiledBackward; forwarded to load_with_key as is_backward.
        return False

    def load(self, example_inputs) -> CompiledFxGraph:
        """Fetch the compiled graph for this entry from FxGraphCache.

        Stores ``example_inputs`` and the constants object on ``self`` for the
        later post_compile() call.

        Raises:
            FXGraphCacheMiss: if FxGraphCache has no entry with the same key and
                exactly the same guard expression.
        """
        from .autograd_cache import FXGraphCacheMiss

        # [Note: AOTAutogradCache and FXGraphCache Guard interactions]
        # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments.
        # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph.
        # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly
        # the same as the ones it passes to inductor, for both the forward and backward passes.
        # (This does not mean that the tensor values passed in are the same: only that their symints are).
        # That is, AOTAutograd and Inductor never create new guards based on symints with different sources
        # than those passed to it by inductor.
        # load() itself does not touch fx_config; post_compile(), which mutates it,
        # is a separate step that runs after load().

        # Clear CompiledTritonKernels before loading from FXGraphCache
        torch._inductor.async_compile.CompiledTritonKernels.cache_clear()
        remote_cache = None
        constants = CompiledFxGraphConstants()
        if should_use_remote_fx_graph_cache():
            remote_cache = FxGraphCache.get_remote_cache()
        (cache_key, debug_lines) = self.fx_graph_cache_info

        def check_exact_guard_match(guard_expr, _hints):
            """
            AOTAutogradCache tracks its own guards, so we just need to treat these guard expressions as a second
            cache key of sorts: we just check for equality, i.e. the FXGraphCache entry with
            the exact same guards as we originally saved into the cache.
            """
            return guard_expr == self.fx_graph_guard_expr

        result, cache_info = FxGraphCache.load_with_key(
            cache_key,
            debug_lines,
            example_inputs,
            local=True,
            remote_cache=remote_cache,
            is_backward=self._is_backward(),
            constants=constants,
            evaluate_guards=check_exact_guard_match,
        )
        if result is None:
            log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_info)
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "fx_graph_cache_miss",  # miss path: nothing was loaded
                    "encoding": "json",
                },
                payload_fn=lambda: json.dumps(cache_info),
            )

            raise FXGraphCacheMiss

        # No need to log chromium event because AOTAutograd will log that immediately for us
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_graph_cache_hit",  # always a hit
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(cache_info),
        )
        # Kept for post_compile(), which needs the same inputs and constants.
        self.example_inputs = example_inputs
        self.constants = constants
        return result

    def post_compile(
        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
    ) -> CompiledFxGraph:
        """
        Called after FXGraphCacheLoadable.load, mutates fx_config
        """
        result.post_compile(self.example_inputs, self.constants, fx_config)
        return result


@dataclass
class CompiledForward(FxGraphCacheLoadable):
    """
    Cacheable entry for a forward function
    """

    def _is_backward(self) -> bool:
        return False


@dataclass
class GenericCompiledBackward(InductorOutput[TOut]):
    """Backward-specific metadata carried next to a backward InductorOutput."""

    # Used by AOTDispatchAutograd.post_compile
    backward_state_indices: list[int]
    num_symints_saved_for_bw_: int


@dataclass
class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable):
    """
    Cacheable entry for a backward function
    """

    def _is_backward(self) -> bool:
        return True

    def post_compile(
        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
    ) -> CompiledFxGraph:
        """Run the inherited post_compile, then wrap the result in ``torch._dynamo.disable``."""
        compiled_bw = super().post_compile(result, fx_config)
        # See note [Wrapping bw_compiler in disable]
        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
        return torch._dynamo.disable(  # type: ignore[return-value]
            compiled_bw, reason="do not trace generated backwards pass"
        )


# Generic bundled forward/backward classes that work with any OutputCode type
@dataclass
class BundledCompiledForward(
    BundledOutputCodeLoadable[TOutputCode], Generic[TOutputCode]
):
    """
    Generic forward function for bundled compilation.
    Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.)
    """


@dataclass
class BundledCompiledBackward(
    GenericCompiledBackward[TOutputCode],
    BundledOutputCodeLoadable[TOutputCode],
    Generic[TOutputCode],
):
    """
    Generic backward function for bundled compilation.
    Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.)
    """

    def post_compile(
        self, result: TOutputCode, fx_config: _CompileFxKwargs
    ) -> TOutputCode:
        """Run the inherited post_compile, then wrap the result in ``torch._dynamo.disable``."""
        compiled_bw = super().post_compile(result, fx_config)
        # See note [Wrapping bw_compiler in disable]
        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
        return torch._dynamo.disable(  # type: ignore[return-value]
            compiled_bw, reason="do not trace generated backwards pass"
        )
+ + def __init__(self, gm: torch.fx.GraphModule): + self.fn, self.args = gm.__reduce__() + + def deserialize(self) -> torch.fx.GraphModule: + gm = self.fn(*self.args) + assert isinstance(gm, torch.fx.GraphModule) + return gm + + +def serialize_graph_module(gm: torch.fx.GraphModule) -> SerializedGraphModule: + # NOTE: mutates the graph module + gm.meta = {} + for node in gm.graph.nodes: + node.meta = {} + return SerializedGraphModule(gm) + + +TForward = TypeVar("TForward", bound=InductorOutput) +TBackward = TypeVar("TBackward", bound=GenericCompiledBackward) + + +@dataclass +class GenericAOTAutogradResult(Generic[TForward, TBackward]): + """A single result from AOT Autograd compilation, genericized by Forward and Backward types. + + A TForward is always an InductorOutput of some sort, which represents the + forward graph of the compile. + A TBackward is an InductorOutput + metadata about the backward, useful for specific + backward-only wrappers. This type is encapsulated by GenericCompiledBackward. + + Each AOTAutogradResult is essentially parameterized by 1. the method of loading + from the cache (either Bundled or UnBundled), and 2. The type of the output. For now, + the only type of output we support is Python Wrapper output, i.e. OutputCode.CompiledFxGraph, + but the same technique works for C++ wrapper code; we'd just add an extra InductorOutput type. 
+ """ + + # Forward and Backward info + compiled_fw: TForward + compiled_bw: Optional[TBackward] + + # Code of the joint graph using print_readable() + # Used for logging purposes + aot_joint_graph_str: Optional[str] + aot_forward_graph_str: Optional[str] + aot_backward_graph_str: Optional[str] + + # Runtime_metadata saved right before compilation + runtime_metadata: ViewAndMutationMeta + + # Wrappers that run after each aot_dispatch_* function + dispatch_wrappers: list[CompilerWrapper] + + # Used by AOTSubclassWrapper + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + # Used by RuntimeWrapper + indices_of_inps_to_detach: list[int] + + # Time taken to trace/compile the forward + # forward_time_taken includes AOTAutograd tracing time + inductor compilation time + # backward_time_taken is essentially just the time inductor took to compile + forward_time_taken_ns: int + backward_time_taken_ns: int + + # Used by standalone_compile + sanitized_aot_config: AOTConfig + + guards_expr: Optional[str] + + # Used by Compiled Autograd + serialized_bw_module: Optional[SerializedGraphModule] + + def pre_save(self): + """ + Perform any preparations to make the result ready for serialization. + """ + self.compiled_fw.pre_save() + if self.compiled_bw is not None: + self.compiled_bw.pre_save() + + # Turn result into the original callable + def wrap_post_compile( + self, + args: list[torch.Tensor], + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ) -> Callable: + """ + This function takes a result and carefully reconstructs the original callable + that AOTAutograd returned the first time it was run. It does this by running the various + post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. + + In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. + In the autograd path, this consists of AOTAutogradDispatch.post_compile. 
+ + The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + + Notably absent from the cached path are: + - DebugAssertWrapper + - FakifiedOutWrapper + + Which we'll handle separately later on, if necessary. + """ + from torch._dynamo.utils import CompileEventLogger, dynamo_timed + + # Log the output of AOTAutogradCache + if aot_config.enable_log: + # TODO: maybe also log to aot_graphs_log + # Unfortunately aot_graphs_log uses + # slightly different formatting though + if self.aot_joint_graph_str is not None: + torch._logging.trace_structured( + "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str + ) + + if self.aot_forward_graph_str is not None: + from torchgen.utils import dataclass_repr + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.runtime_metadata), + ) + if self.maybe_subclass_meta is not None: + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), + ) + + # It's called an inference graph if not running with autograd + name = ( + "aot_forward_graph" + if self.aot_backward_graph_str is not None + else "aot_inference_graph" + ) + torch._logging.trace_structured( + name, payload_fn=lambda: self.aot_forward_graph_str + ) + + if self.aot_backward_graph_str is not None: + torch._logging.trace_structured( + "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str + ) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # 
Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + else: + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) + + # Wrap the forward function in post compile wrappers + compiled_fw_func = AOTDispatchSubclassWrapper( + trace_joint=needs_autograd, + fw_only=None, + maybe_subclass_meta=self.maybe_subclass_meta, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + req_subclass_dispatch = self.maybe_subclass_meta is not None + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + # In autograd case, functionalizedRngWrapper should not modify outs + return_new_outs = not needs_autograd + compiled_fw_func = FunctionalizedRngRuntimeWrapper( + return_new_outs=return_new_outs + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + compiled_fw_func._boxed_call = True + disable_amp = torch._C._is_any_autocast_enabled() + + if needs_autograd: + assert self.compiled_bw is not None + + cached_lazy_backward = None + if self.serialized_bw_module is not None: + cached_lazy_backward = CachedAutogradLazyBackwardCompileInfo( + self.serialized_bw_module.deserialize + ) + # This 
class AOTAutogradResult(GenericAOTAutogradResult[CompiledForward, CompiledBackward]):
    """
    Regular AOTAutogradResult: saves the forward/backward FxGraphCache keys
    and looks them up in FxGraphCache on load.

    This is GenericAOTAutogradResult specialized to CompiledForward and
    CompiledBackward; it declares no fields or methods of its own.
    """
def deserialize_bundled_cache_entry(entry: BundledAOTAutogradResult) -> Callable:
    """
    Turn a bundled cache entry into a callable that is itself re-serializable.

    Args:
        entry: The bundled cache entry to post-compile. NOTE: this function
            mutates it (``entry.guards_expr`` is reset to ``None``).

    Returns:
        A ``forward(*runtime_args)`` wrapper that passes its positional
        arguments to the post-compiled function as a single list. The wrapper
        carries a ``serialize`` attribute taken from the underlying
        SerializableCompiledFunction.
    """
    from copy import deepcopy

    from torch._inductor.cudagraph_utils import BoxedDeviceIndex
    from torch._inductor.utils import BoxedBool

    # In the precompile use case, guards are already serialized
    # by dynamo, so we don't need to add them to the environment
    entry.guards_expr = None
    # TODO: this isn't exactly right, because cudagraphs needs to be a shared config
    # which is set by compile_fx. But in precompile, we never actually call compile_fx
    # so we don't have a place to track cudagraphs here.
    cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)
    boxed_forward_device_index = BoxedDeviceIndex(None)
    # We need to make a clean copy of the cache entry
    # in case it needs to be serialized again.
    # The copy is taken after guards_expr is cleared and before
    # wrap_post_compile runs (presumably because post-compile may alter
    # the entry -- confirm against wrap_post_compile).
    serializable_copy = deepcopy(entry)

    from torch._subclasses import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Reuse the ambient tracing context when there is one.
    context = torch._guards.TracingContext.try_get()
    if context is None:
        # Create a clean environment when running fx graph post compile
        # if one is not available
        context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
    with torch._guards.tracing(context):
        # No example args are supplied here: an empty list is passed.
        compiled_fn = entry.wrap_post_compile(
            [],
            entry.sanitized_aot_config,
            {
                "cudagraphs": cudagraphs,
                "boxed_forward_device_index": boxed_forward_device_index,
            },
        )
    # Ensure the deserialized cache entry is still serializable

    # The serialize callback hands back the pre-post-compile deep copy.
    compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: serializable_copy)

    # TODO: this ignores flat_params, which can exist
    # if inline_builtin_nn_modules=False
    @simple_wraps(compiled_fn)
    def forward(*runtime_args: Any) -> Any:
        # compiled_fn is called with one list argument rather than varargs.
        return compiled_fn(list(runtime_args))

    assert hasattr(compiled_fn, "serialize")
    forward.serialize = compiled_fn.serialize  # type: ignore[attr-defined]

    return forward
+# mypy: allow-untyped-defs +""" +Utils for caching the outputs of AOTAutograd +""" + +from __future__ import annotations + +import base64 +import contextlib +import functools +import json +import logging +import os +import pickle +import random +import shutil +import time +import traceback +from copy import copy +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import override + +import torch +from torch._dynamo.precompile_context import PrecompileContext +from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions +from torch._dynamo.utils import chromium_event_log_active, CompileEventLogger, counters +from torch._functorch import config +from torch._inductor.codecache import ( + _ident, + add_ephemeral_timeout_increase_for_distributed, + BypassFxGraphCache, + create_cache, + extract_tensor_metadata_for_cache_key, + FxGraphCache, + FxGraphCachePickler, + FxGraphHashDetails, + GuardedCache, + sha256_hash, + write_atomic, +) +from torch._inductor.output_code import OutputCode +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import BoxedBool, should_use_remote_fx_graph_cache +from torch._logging import LazyString +from torch._utils_internal import log_cache_bypass +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.fx.experimental.symbolic_shapes import hint_int +from torch.utils._triton import has_triton_package + +from .aot_autograd_result import ( + AOTAutogradResult, + BundledAOTAutogradCacheArtifact, + BundledAOTAutogradResult, + BundledCompiledBackward, + BundledCompiledForward, + CompiledBackward, + CompiledForward, + GenericAOTAutogradResult, + SerializedGraphModule, +) +from .runtime_wrappers import ( + CompilerWrapper, + SerializableCompiledFunction, + SubclassMeta, +) +from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta # noqa: F401 + + +if TYPE_CHECKING: + from collections.abc import 
Callable + + from torch._inductor.compile_fx import _CompileFxKwargs + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.remote_cache import JsonDataTy, RemoteCache + from torch.fx.node import Node + + +log = logging.getLogger(__name__) + + +class BypassAOTAutogradCache(Exception): + pass + + +# Used to signify when FXGraphCache missed when AOTAutogradCache uses it +class FXGraphCacheMiss(BypassAOTAutogradCache): + pass + + +def should_use_remote_autograd_cache(): + if torch.compiler.config.force_disable_caches: + return False + if config.enable_remote_autograd_cache is not None: + return config.enable_remote_autograd_cache + if not config.is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk_name = "pytorch/remote_cache:aot_autograd_cache_version" + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name) + + +def should_use_local_autograd_cache(): + if torch.compiler.config.force_disable_caches: + return False + return config.enable_autograd_cache + + +def should_bundle_autograd_cache(): + return config.bundled_autograd_cache or torch._dynamo.config.caching_precompile + + +def check_node_safe(node: Node): + """ + Checks that the node only uses supported operators. We are starting with very + conservative cacheability constraints, and incrementally adding more support as we expand. 
+ + [Note: AOTAutograd Cacheability checks] + - Our cache key is computed from the FX graph produced by Dynamo and the input example values + - A node is "safe" if the same cache key results in a compiled artifact that has the same behavior + (i.e, the set of inputs that go into our cache key is sufficient to distinguish its behavior) + + To accomplish this safety check, we consider the following functions to be safe: + - Public functions under modules torch, torch.functional, and torch.nn.functional: these are + allowed in the graph by dynamo, so we can assume they are safe to cache. + - method calls on base tensor types + - Any call_module that dynamo deemed safe to allow AOTAutograd to trace + - Non callable nodes, such as placeholder, output, get_attr + + The test suite test_aot_autograd_cache.py::AOTAutogradCachePicklerTests tries its best to fully cover/specify this behavior. + """ + SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional") + SAFE_TORCH_FUNCTIONS = ( + "torch.Size", + "torch.Tensor", + "torch.sym_int", + "torch._sym_sqrt", + "torch.sym_float", + "torch.sym_sum", + ) + SAFE_NON_TORCH_FUNCTIONS = ( + "einops.einops.rearrange", + "einops.einops.repeat", + ) + + def is_public_torch_api(target): + # Don't blindly allow private functions in the torch namespace + is_private = target.__name__.startswith("_") + + return ( + getattr(target, "__module__", None) in SAFE_TORCH_MODULES and not is_private + ) + + def is_safe_torch_function(target): + """Allowlisted torch functions""" + function_name = f"{target.__module__}.{target.__name__}" + # Allow torch.autograd.function.FunctionCtx if custom autograd functions are allowed + if function_name == "torch.autograd.function.FunctionCtx": + return ( + torch._functorch.config.autograd_cache_allow_custom_autograd_functions + ) + + # Functions in torch_non_c_binding_in_graph_functions + # are guaranteed to be cache safe. 
+ # See NOTE: [Cacheability of in-graph torch functions] + return ( + function_name in torch_non_c_binding_in_graph_functions + or function_name in SAFE_TORCH_FUNCTIONS + or function_name in torch._inductor.config.unsafe_marked_cacheable_functions + ) + + def is_cacheable_function(target): + if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return True + if is_public_torch_api(target): + return True + # Technically, FXGraphCache._check_for_hop already checks this, + # but better to error earlier anyway + if isinstance(target, torch._ops.HigherOrderOperator): + return target.cacheable() + is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method" + if is_builtin_fun_or_type: + return True + if is_safe_torch_function(target): + return True + function_name = f"{target.__module__}.{target.__name__}" + if function_name in SAFE_NON_TORCH_FUNCTIONS: + return True + return False + + def is_tensor(target): + # Tensors always have example values in meta field + return "example_value" in target.meta + + # I'd love to use a match statement here, but it wasn't introduced until py3.10 + if node.op == "call_function": + if node.meta and node.meta.get("is_wrapped", False): + # This is fx.wrap function + # By default we BypassAOTAutogradCache for unknown functions, + # But if user explicitly specified cache hash - allow to cache it. + if node.meta.get("user_cache_hash", None): + return + + if not is_cacheable_function(node.target): + module = getattr(node.target, "__module__", None) + name = getattr(node.target, "__name__", None) + raise BypassAOTAutogradCache( + f"Unsupported call_function target {node.target}. 
\n Function module: {module}, \nFunction name: {name}" + ) + elif node.op == "call_method": + method_name = node.target + method_target = node.args[0] + # Only support method calls on base tensors + if not is_tensor(method_target): + module = getattr(method_target, "__module__", None) + name = getattr(method_target, "__name__", None) + raise BypassAOTAutogradCache( + f"Unsupported call_method target {method_target}. \nMethod module: {module}, \nMethod name: {name}" + ) + if ( + type(method_name) is not str + and type(method_name).__name__ != "method_descriptor" + ): + raise BypassAOTAutogradCache( + f"Unsupported call_method method {node.target}: {method_name}" + ) + # Cache safe + elif node.op in ("placeholder", "get_attr", "call_module", "output"): + # Assumption today for call_module being a safe op: + # (1) today the only call_module ops that can show up in a graph come from "built-in-nn-modules" + # that dynamo assumes are safe to trace. If dynamo assumes they are safely to blindly trace, then + # they should be safe to cache as well. + # (2) in the steady-state (some time in H2?) we shouldn't see these anymore, once inline builtin nn modules by default + # (3) We do not allow user made nn modules in the graph today, only function calls. 
def check_cacheable(gm: torch.fx.GraphModule):
    """
    Verify that a graph module is eligible for AOTAutograd caching.

    Raises BypassAOTAutogradCache when freezing is enabled, when neither the
    local nor the remote FX graph cache is enabled, or when the active tracing
    context has fakify_first_call set. Every node is then passed through
    check_node_safe. If saved-tensors-hooks subgraphs are attached to ``gm``,
    they are checked recursively.
    """
    # Resolve the node list up front, before any of the config checks run.
    graph_nodes = gm.graph.nodes

    inductor_config = torch._inductor.config
    if inductor_config.freezing:
        raise BypassAOTAutogradCache("Cannot cache a graph with freezing enabled")

    if not inductor_config.fx_graph_cache and not should_use_remote_fx_graph_cache():
        raise BypassAOTAutogradCache("FX graph cache is not enabled")

    ctx = torch._guards.TracingContext.try_get()
    if ctx and ctx.fakify_first_call:
        raise BypassAOTAutogradCache(
            "Won't cache a graph with fakify_first_call enabled"
        )

    for graph_node in graph_nodes:
        check_node_safe(graph_node)

    # Saved tensors hooks are globally set subgraphs that are not used
    # explicitly in the main graph; they are inlined in aot_autograd graphs
    # and only matter here for the caching logic.
    if hasattr(gm, "saved_tensors_hooks_pack_0"):
        # The unpack subgraph is guaranteed to exist whenever the pack
        # subgraph does, so it is fetched without its own hasattr check.
        for hook_attr in (
            "saved_tensors_hooks_pack_0",
            "saved_tensors_hooks_unpack_0",
        ):
            check_cacheable(getattr(gm, hook_attr))  # type: ignore[arg-type]
+ """ + + def get_triton_source_codes_from_gm( + self, + gm: torch.fx.GraphModule, + ): + assert has_triton_package(), "Triton is not available" + + triton_kernels = [] + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if isinstance(node.target, torch._ops.OpOverloadPacket): + attrs = node.target._dir + for attr in attrs: + if custom_op := getattr(node.target, attr, None): + kernels = torch._library.triton.get_triton_kernels_for_op( + custom_op._name + ) + triton_kernels.extend(kernels) + elif isinstance(node.target, torch._ops.OpOverload): + kernels = torch._library.triton.get_triton_kernels_for_op( + node.target._name + ) + triton_kernels.extend(kernels) + + triton_kernel_source_codes = [] + from torch._inductor.codegen.wrapper import ( + user_defined_triton_kernel_transitive_closure_source_code, + ) + + for kernel in triton_kernels: + from triton.runtime.autotuner import Autotuner + + if isinstance(kernel, Autotuner): + # Grab the Inner JITFunction + kernel = kernel.fn + source_codes = user_defined_triton_kernel_transitive_closure_source_code( + kernel + ) + triton_kernel_source_codes.append(source_codes) + + return triton_kernel_source_codes + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs, + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ): + # FxGraphHashDetails contains all the keys related to inductor. 
class AOTAutogradCachePickler(FxGraphCachePickler):
    """
    FxGraphCachePickler with extra reducers for AOTConfig and torch.Tensor,
    so that both are reduced to stable values when computing a cache key.
    """

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        # pyrefly: ignore [bad-override]
        self.dispatch_table: dict
        # Register (or override) the reducers on top of the inherited table.
        self.dispatch_table[AOTConfig] = functools.partial(self._reduce_aot_config)
        self.dispatch_table[torch.Tensor] = functools.partial(self._reduce_tensor)

    def _reduce_aot_config(self, aot_config: AOTConfig):
        """
        Reduce an AOTConfig to the tuple of fields that feed the cache key.
        """
        # Field order is part of the key; keep it fixed.
        key_fields = (
            "num_params_buffers",
            "keep_inference_input_mutations",
            "is_export",
            "no_tangents",
            "dynamic_shapes",
            "aot_autograd_arg_pos_to_source",
            "enable_log",
            "pre_dispatch",
        )
        stable_key = tuple(getattr(aot_config, field) for field in key_fields)
        return (_ident, stable_key)

    def _reduce_tensor(self, tensor):
        """
        Reduce a tensor to its cache-key metadata.
        """
        return (_ident, (extract_tensor_metadata_for_cache_key(tensor),))
def _triton_version_older_than(version: str, minimum: tuple[int, ...]) -> bool:
    """
    Return True if ``version`` names a release older than ``minimum``.

    Only the leading numeric release components are compared, as integers, so
    "3.10.0" correctly sorts after "3.2.0". A local label ("+git1234") and any
    pre-release suffix ("0rc1") are ignored. If no numeric component can be
    parsed, this falls back to the plain string comparison used historically.
    """
    release: list[int] = []
    # Drop any local version label before splitting into components.
    for piece in version.split("+", 1)[0].split("."):
        digits = piece[: len(piece) - len(piece.lstrip("0123456789"))]
        if not digits:
            break
        release.append(int(digits))
        if digits != piece:
            # A suffix such as "0rc1" ends the numeric release segment.
            break

    if not release:
        return version < ".".join(str(part) for part in minimum)

    # Zero-pad so that (3, 2) and (3, 2, 0) compare equal.
    width = max(len(release), len(minimum))
    parsed = tuple(release) + (0,) * (width - len(release))
    floor = tuple(minimum) + (0,) * (width - len(minimum))
    return parsed < floor


def autograd_cache_key(
    gm: torch.fx.GraphModule,
    example_inputs,
    config: AOTConfig,
    fx_config: _CompileFxKwargs,
    # TODO: add args and parameters
) -> tuple[str, list[str]]:
    """
    Generate a unique hash of the FX graph for caching.

    Returns:
        ``(key, debug_lines)``: the cache key (prefixed with "a") and the
        pickler's debug lines describing what went into it.

    Raises:
        Any exception raised while checking cacheability or hashing
        (typically BypassAOTAutogradCache) propagates to the caller, unless
        ``torch._dynamo.config.enable_aot_compile`` is set. In that case the
        failure is logged and a random nonce key with no debug lines is
        returned instead.
    """

    try:
        check_cacheable(gm)
        if has_triton_package():
            # Due to https://github.com/triton-lang/triton/issues/3729,
            # if triton is < 3.2.0, AOTAutogradCache may cause us to
            # attempt to load a cache entry without initializing
            # the CUDA context on the autograd thread.

            # Without caching, we naturally do this initialization when
            # tracing through the graph with the autograd engine.
            import triton

            # Compare numerically: a plain string comparison would rank
            # "3.10.0" below "3.2.0" and bypass the cache on newer triton.
            if _triton_version_older_than(triton.__version__, (3, 2, 0)):
                raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0")
        details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
        pickler = AOTAutogradCachePickler(gm)
        # The prefix distinguishes among the other kinds of objects we cache
        key = "a" + pickler.get_hash(details)
        debug_lines = pickler.debug_lines(details)
        log.debug(
            "Autograd graph cache hash details for key %s:\n%s",
            key,
            LazyString(lambda: "\n".join(debug_lines)),
        )
        return key, debug_lines
    except Exception:
        # If enable_aot_compile is set, we're in AOT precompile mode where we always
        # want to use fallback nonce keys. Unlike caching, it's fine if we can't generate
        # a proper key because we are guaranteed in an AOT precompile world users are in
        # complete control of distributing and loading artifacts.
        if torch._dynamo.config.enable_aot_compile:
            log.info(
                "Failed to generate AOTAutograd cache key; falling back to nonce due to enable_aot_compile",
                exc_info=True,
            )
            return str(random.random()), []
        raise
@CacheArtifactFactory.register
class AOTAutogradCacheArtifact(CacheArtifact):
    """Cache artifact for AOTAutograd entries, registered as "aot_autograd"."""

    @override
    @staticmethod
    def type():
        """Return the artifact type name under which this class is registered."""
        return "aot_autograd"

    @override
    def populate_cache(self):
        """Write this artifact's content into the local AOTAutograd cache under its key."""
        AOTAutogradCache._write_to_local_cache(self.key, self.content)
+ In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates + based on the real tensor inputs, which can contain symints. + + # Cache Outputs (AOTAutogradResult) + - AOTAutogradCache caches the following values: + - The compiled forward and backward functions from inductor, via keys to the FXGraphCache + - Metadata to reconstruct the AOTModule from the compiled inductor artifacts + - See AOTAutogradResult for more info + + [Note: Caching guards generated by AOTAutograd and Inductor] + AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each + compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions + from FXGraphCache, giving it new symint arguments from the input args. + FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards. + **No new guards are generated into the shape env after inductor finishes compiling**, so the guards + saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches. 
+ """ + + @staticmethod + def clear(): + """Clear the cache""" + try: + shutil.rmtree(AOTAutogradCache._get_tmp_dir()) + except FileNotFoundError: + pass + + @staticmethod + def try_load( + mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper], + args, + aot_config: AOTConfig, + cudagraphs: BoxedBool, + boxed_forward_device_index: Optional[BoxedDeviceIndex], + local: bool, + remote: bool, + ) -> Optional[Callable]: + """ + Load a result from the cache, and reconstruct a runtime wrapper around the object + """ + gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod + with sanitize_gm_for_cache(gm): + compiled_fn = None + cache_info: dict[str, Any] = {} + cache_key = None + debug_lines: list[str] = [] + cache_event_time = time.time_ns() + cache_state = None + fx_config: _CompileFxKwargs = { + "cudagraphs": cudagraphs, + "boxed_forward_device_index": boxed_forward_device_index, + } + try: + cache_key, debug_lines = autograd_cache_key( + gm, args, aot_config, fx_config + ) + result: Optional[tuple[GenericAOTAutogradResult, bytes]] = ( + AOTAutogradCache._lookup( + cache_key, local, remote, args, cache_info, aot_config + ) + ) + if result is not None: + (entry, pickled_content) = result + compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) + # Make the compiled_fn serializable, where the serialize function just + # makes a copy of the original entry before post compile via the pickled content + compiled_fn = SerializableCompiledFunction( + compiled_fn, lambda: pickle.loads(pickled_content) + ) + log.info("AOTAutograd cache hit for key %s", cache_key) + + counters["aot_autograd"]["autograd_cache_hit"] += 1 + cache_state = "hit" + cache_event_time = time.time_ns() + forward_time_saved = entry.forward_time_taken_ns // 1e6 + backward_time_saved = entry.backward_time_taken_ns // 1e6 + cache_info.update( + { + "forward_time_saved_ms": forward_time_saved, + "backward_time_saved_ms": backward_time_saved, + "time_saved_ms": 
forward_time_saved + backward_time_saved, + } + ) + time_saved_ns = ( + entry.forward_time_taken_ns + entry.backward_time_taken_ns + ) + # TODO: should we use the same field for remote cache time saved for both + # FXGraphCache and AOTAutogradCache? + # get_metrics_context().increment(...) + if ( + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + + if compiled_fn is None: + log.info("AOTAutograd cache miss for key %s", cache_key) + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + cache_event_time = time.time_ns() + # Count missing the FXGraphCache as a miss not a bypass + except FXGraphCacheMiss as e: + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + if ( + config.strict_autograd_cache + or torch._dynamo.config.strict_precompile + ): + raise e + # Most often this is BypassAOTAutogradCache, but + # if there's ever different reason we can't cache, + # we still never want to hard throw an exception, since + # we can always fallback to a cache bypass. + # As an example, if the user calls autograd via + # standalone inductor, we will sometimes get a GraphModule + # that doesn't actually have a `.graph` on it. Instead + # of checking every single case, we safely catch the exception + # in those cases. + except Exception as e: + cache_key = None + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) # noqa: G200 + cache_state = "bypass" + cache_event_time = time.time_ns() + cache_info["cache_bypass_reason"] = str(e) + cache_info["cache_bypass_exception_type"] = type(e).__name__ + cache_info["cache_bypass_traceback"] = traceback.format_exc().split( + "\n" + ) + # TODO: this gets logged implicitly by cache_bypass_reason, + # and here we explicitly log it into tlparse. + # We may want to log this as an extra column in Scuba, though. 
+ cache_info["cache_bypass_hard_exception"] = not isinstance( + e, BypassAOTAutogradCache + ) + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + if ( + config.strict_autograd_cache + or torch._dynamo.config.strict_precompile + ): + raise e + if compiled_fn is None: + # Set the cache key so we can save a cache result later + symints = AOTAutogradCache._filter_backed_symints(args) + if cache_key is not None: + aot_config.cache_info = AOTAutogradCacheInfo( + cache_key, + time.time_ns(), + forward_symints=symints, + ) + + cache_info.update( + { + "key": cache_key, + "cache_state": cache_state, + "components": debug_lines, + } + ) + if chromium_event_log_active(): + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"aotautograd_cache_{cache_state}", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + + return compiled_fn + + @classmethod + def generate_guards_expression( + cls: type[AOTAutogradCache], cache_info: AOTAutogradCacheInfo + ) -> Optional[str]: + shape_env = cls._get_shape_env() + assert shape_env is not None + symints = cache_info.forward_symints + guards = shape_env.get_pruned_guards(symints) + return shape_env.produce_guards_expression(placeholders=symints, guards=guards) + + @classmethod + def _get_tmp_dir(cls: type[AOTAutogradCache]) -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. 
    @classmethod
    def _get_tmp_dir_for_key(cls: type[AOTAutogradCache], key: str) -> str:
        """
        Get the per-key subdirectory of the toplevel temporary directory, i.e.
        ``cls._get_tmp_dir()`` joined with ``key``.
        """
        return os.path.join(cls._get_tmp_dir(), key)
+ PrecompileContext.record_artifact( + BundledAOTAutogradCacheArtifact( + aot_config.precompile_backend_id, entry + ), + ) + except Exception as e: + log.info("AOTAutograd cache unable to load compiled graph: %s", e) # noqa: G200 + if config.strict_autograd_cache: + raise e + if entry is not None: + assert pickled_content is not None + return (entry, pickled_content) + else: + return None + + @staticmethod + def _write_to_local_cache(key: str, content: bytes): + """Write an entry to the local cache.""" + subdir = AOTAutogradCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized entry to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + log.info("Writing AOTAutograd cache entry to %s", path) + write_atomic(path, content) + + @staticmethod + def save(key: str, entry: GenericAOTAutogradResult, remote: bool): + """Save a single entry into the cache.""" + try: + entry.pre_save() + content = pickle.dumps(entry) + CacheArtifactManager.record_artifact( + AOTAutogradCacheArtifact.type(), key, content + ) + if ( + should_bundle_autograd_cache() + and entry.sanitized_aot_config.precompile_backend_id is not None + ): + precompile_key = entry.sanitized_aot_config.precompile_backend_id + artifact = BundledAOTAutogradCacheArtifact(precompile_key, entry) + # Now that we're saving it, the precompile_backend_id field is no longer + # useful, remove it from the entry. 
+ entry.sanitized_aot_config.precompile_backend_id = None + PrecompileContext.record_artifact(artifact) + AOTAutogradCache._write_to_local_cache(key, content) + counters["aot_autograd"]["autograd_cache_saved"] += 1 + except BypassAOTAutogradCache as e: + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) # noqa: G200 + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + return None + except Exception as e: + log.info("AOTAutograd cache unable to serialize compiled graph: %s", e) # noqa: G200 + if remote: + log_cache_bypass( + "bypass_aot_autograd", "Unable to serialize: " + str(e) + ) + if config.strict_autograd_cache: + raise e + return None + + if remote: + remote_cache: Optional[RemoteCache[JsonDataTy]] = ( + AOTAutogradCache.get_remote_cache() + ) + if remote_cache is not None: + time_taken_ms = int( + (entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6 + ) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + + @staticmethod + @functools.cache + def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + """ + Attempts to load the remote cache, returns None on error. 
+ """ + cache_id = "autograd-experimental" + return create_cache( + cache_id, + config.is_fbcode(), + "FbRemoteAOTAutogradCache", + "RemoteAOTAutogradCache", + ) + + @staticmethod + def make_entry( + compiled_fw_func: OutputCode, + compiled_bw_func: Optional[OutputCode], + aot_joint_graph_str: Optional[str], + aot_forward_graph_str: Optional[str], + aot_backward_graph_str: Optional[str], + runtime_metadata: ViewAndMutationMeta, + dispatch_wrappers: list[CompilerWrapper], + maybe_subclass_meta: Optional[SubclassMeta], + num_fw_outs_saved_for_bw: Optional[int], + indices_of_inps_to_detach: list[int], + forward_time_taken_ns: int, + backward_time_taken_ns: int, + sanitized_aot_config: AOTConfig, + guards_expr: Optional[str], + backward_state_indices: Optional[list[int]], + num_symints_saved_for_bw: Optional[int], + serialized_bw_module: Optional[SerializedGraphModule], + ) -> GenericAOTAutogradResult: + if should_bundle_autograd_cache(): + # Helper function to unwrap all the wrappers we added during aotdispatch + # They get reapplied on cache load + def unwrap_output_code(obj): + while hasattr(obj, "__wrapped__"): + obj = obj.__wrapped__ + assert isinstance(obj, OutputCode) + return obj + + compiled_fw_graph = unwrap_output_code(compiled_fw_func) + bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph) + bundled_compiled_backward = None + if compiled_bw_func is not None: + assert backward_state_indices is not None + assert num_symints_saved_for_bw is not None + compiled_bw_graph = unwrap_output_code(compiled_bw_func) + bundled_compiled_backward = BundledCompiledBackward( + compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw + ) + + return BundledAOTAutogradResult( + compiled_fw=bundled_compiled_forward, + compiled_bw=bundled_compiled_backward, + aot_joint_graph_str=aot_joint_graph_str, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=aot_backward_graph_str, + runtime_metadata=runtime_metadata, + 
dispatch_wrappers=dispatch_wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + indices_of_inps_to_detach=indices_of_inps_to_detach, + forward_time_taken_ns=forward_time_taken_ns, + backward_time_taken_ns=backward_time_taken_ns, + sanitized_aot_config=sanitized_aot_config, + guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, + ) + + else: + fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) + fw_debug_lines = getattr( + compiled_fw_func, "_fx_graph_cache_debug_lines", [] + ) + + assert fw_key is not None + compiled_forward = CompiledForward( + fx_graph_cache_info=(fw_key, fw_debug_lines), + fx_graph_guard_expr=getattr(compiled_fw_func, "guards_expr", None), + ) + compiled_backward = None + if compiled_bw_func is not None: + bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) + bw_debug_lines = getattr( + compiled_bw_func, "_fx_graph_cache_debug_lines", [] + ) + assert bw_key is not None + assert backward_state_indices is not None + assert num_symints_saved_for_bw is not None + compiled_backward = CompiledBackward( + fx_graph_cache_info=(bw_key, bw_debug_lines), + fx_graph_guard_expr=getattr(compiled_bw_func, "guards_expr", None), + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw_=num_symints_saved_for_bw, + ) + + return AOTAutogradResult( + compiled_fw=compiled_forward, + compiled_bw=compiled_backward, + aot_joint_graph_str=aot_joint_graph_str, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=aot_backward_graph_str, + runtime_metadata=runtime_metadata, + dispatch_wrappers=dispatch_wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + indices_of_inps_to_detach=indices_of_inps_to_detach, + forward_time_taken_ns=forward_time_taken_ns, + backward_time_taken_ns=backward_time_taken_ns, + sanitized_aot_config=sanitized_aot_config, + guards_expr=guards_expr, + 
serialized_bw_module=serialized_bw_module, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/frontend_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/frontend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..041d321fec56da208dff93ccac9cd85eabd3b4c0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/frontend_utils.py @@ -0,0 +1,336 @@ +# mypy: ignore-errors + +import warnings +from collections.abc import KeysView +from contextlib import contextmanager +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._guards import detect_fake_mode +from torch._library.opaque_object import is_opaque_type +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .descriptors import BufferAOTInput, DifferentiableAOTInput, ParamAOTInput +from .schemas import AOTConfig, FakifiedFlatArgs + + +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +def process_inputs( + flat_args: list[Any], + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], + ignore_shape_env: bool = False, +) -> FakifiedFlatArgs: + with fake_mode: + + def convert(idx, x): + if shape_env is not None and not ignore_shape_env: + from torch._dynamo.source import ConstantSource + + if isinstance(x, int): + # We always specialize on scalar values in export. 
+ if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source, positive=x >= 0), + hint=x, + source=source, + ) + if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + result = fake_mode.from_tensor( + x, + static_shapes=ignore_shape_env, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + return result + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: list[Any], aot_config: AOTConfig +) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + 
def _try_get_metadata_from_dynamo(
    mod: torch.nn.Module,
    param_keys: KeysView[str],
    full_args_num: int,
    full_args_descs: list[DifferentiableAOTInput],
) -> tuple[Optional[list[torch._guards.Source]], list[int]]:
    """
    Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule.
    We first verify that `mod` does come from Dynamo, then we handle cases where
    metadata might be missing.

    Args:
        mod: the module being compiled; only an fx.GraphModule whose ``meta``
            contains "dynamo_compile_id" is treated as coming from Dynamo.
        param_keys: names of the params/buffers lifted as extra leading inputs.
        full_args_num: expected total number of inputs (lifted + placeholders);
            asserted against the number of collected sources.
        full_args_descs: descriptors of all inputs; only consulted on the
            export path to find param/buffer inputs.

    Returns:
        aot_autograd_arg_pos_to_source: used to dedup params and their guards
            (None when `mod` is not from Dynamo or is from export)
        static_input_indices: used to identify static inputs for cudagraphs
    """
    # Note [Assumption on Dynamo Metadata]
    # This function assumes a graph module from dynamo provides `dynamo_compile_id`
    # in its meta, a `_param_name_to_source` attribute, and that every placeholder
    # node has a `_dynamo_source` attribute.
    # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to
    # be propagated in order to be recognized as a dynamo graph

    if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta):
        # graph was not captured by dynamo
        return None, []

    if not hasattr(mod, "_param_name_to_source"):
        # is from export: every param/buffer input descriptor is a static input
        static_input_indices = [
            i
            for i, node in enumerate(full_args_descs)
            if isinstance(node, (ParamAOTInput, BufferAOTInput))
        ]
        return None, static_input_indices

    # We now know this came from dynamo, and (1) we care about guards,
    # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
    # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
    # Additionally, we mark static indices for cudagraphs.
    param_name_to_source = mod._param_name_to_source
    seen_sources = set()

    aot_autograd_arg_pos_to_source = []
    static_input_indices = []
    # Collect the new inputs lifted by aotdispatch; each one is marked static
    for i, name in enumerate(param_keys):
        assert name in param_name_to_source, f"{name} not found."
        source = param_name_to_source[name]
        assert source not in seen_sources, source
        seen_sources.add(source)
        aot_autograd_arg_pos_to_source.append(source)

        static_input_indices.append(i)

    # Collect the dynamo graph inputs
    # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
    # matched tensors back into the Fx graph, this might not be necessary.
    for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
        assert hasattr(node, "_dynamo_source")
        source = node._dynamo_source
        # `source` specifies the source from user code. ddp optimizer may have
        # intermediate values becoming submodule placeholders which do not
        # have a source
        assert source is None or source not in seen_sources, source
        seen_sources.add(source)
        aot_autograd_arg_pos_to_source.append(source)
        # Only used for the debug log lines below
        source_name = source.name if source else str(source)

        # input[i] in dynamo is now:
        # input[i + len(extra_params)] in AOT,
        # where extra_params are the params/buffers that dynamo baked into the
        # OutputGraph
        actual_pos = pos + len(param_keys)

        # A placeholder is static when its tensor_dict carries a truthy
        # "_dynamo_static_input_type" entry
        if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
            "_dynamo_static_input_type", None
        ):
            static_inputs_log.debug(
                "Adding static input pos %s for source %s", actual_pos, source_name
            )
            static_input_indices.append(actual_pos)
        else:
            static_inputs_log.debug(
                "Non-static input pos %s for source %s", actual_pos, source_name
            )

    assert full_args_num == len(aot_autograd_arg_pos_to_source)
    return aot_autograd_arg_pos_to_source, static_input_indices
+ + NN_MODULE_STD_ATTRS = [ + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + def _get_all_module_attributes(mod): + # return attributes from all modules and submodules + result = {} + for name, submodule in mod.named_modules(): + result[name] = _get_attributes(submodule) + return result + + def _restore_all_module_attributes(mod, snapshot): + # restore attributes to all modules and submodules + for name, submodule in mod.named_modules(): + if name in snapshot: + submodule.__dict__.update(snapshot[name]) + + # save state of attributes before enter + snapshot = pytree.tree_map( + lambda x: x, + _get_all_module_attributes(mod), + is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, + ) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + + def _collect_assigned_tensor_attributes(snapshot, new_attrs): + assigned_tensor_attributes = [] + + def _compare_values(path, old_val, new_val): + """Recursively compare values, handling containers.""" + # Same object, no change + if old_val is new_val: + return + + if old_val is None or new_val is None: + if isinstance(new_val, torch.Tensor): + assigned_tensor_attributes.append(path) + return + + # Check if it's a tensor that was reassigned + if 
isinstance(new_val, torch.Tensor): + assigned_tensor_attributes.append(path) + return + + # Handle dict containers + if isinstance(old_val, dict) and isinstance(new_val, dict): + all_keys = set(old_val.keys()) | set(new_val.keys()) + for key in all_keys: + old_item = old_val.get(key) + new_item = new_val.get(key) + _compare_values(f"{path}[{key!r}]", old_item, new_item) + return + + # Handle list/tuple containers + if isinstance(old_val, (list, tuple)) and isinstance( + new_val, (list, tuple) + ): + # Different lengths = mutation happened + max_len = max(len(old_val), len(new_val)) + for i in range(max_len): + old_item = old_val[i] if i < len(old_val) else None + new_item = new_val[i] if i < len(new_val) else None + _compare_values(f"{path}[{i}]", old_item, new_item) + return + + # For other types, just check if they're different objects + # (we don't care about non-tensor mutations) + + for module_name in snapshot.keys() | new_attrs.keys(): + old_module_attrs = snapshot.get(module_name, {}) + new_module_attrs = new_attrs.get(module_name, {}) + + for attr_name in old_module_attrs.keys() | new_module_attrs.keys(): + module_prefix = f"self.{module_name}." if module_name else "self." + full_path = f"{module_prefix}{attr_name}" + + old_val = old_module_attrs.get(attr_name) + new_val = new_module_attrs.get(attr_name) + _compare_values(full_path, old_val, new_val) + + return assigned_tensor_attributes + + new_attrs = _get_all_module_attributes(mod) + assigned_tensor_attributes = _collect_assigned_tensor_attributes( + snapshot, new_attrs + ) + # restore state of all attributes (including, e.g., of primitive types) + _restore_all_module_attributes(mod, snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + warnings.warn( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. 
" + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer).", + stacklevel=2, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/functional_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/functional_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5af4fc9ee11955b4e6151f9602793c9076c48387 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/functional_utils.py @@ -0,0 +1,548 @@ +# mypy: allow-untyped-defs +""" +This file contains utilities related to functionalization in AOTAutograd: +1. converting to/from functional tensors +2. detecting Tensor mutations - both metadata and Tensor value +3. regenerating/replaying views from their base +4. checking if a graph is functional i.e. 
whether it contains any mutation ops +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor +from torch._C import _functionalization +from torch._logging import getArtifactLogger +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq, SymIntEqByExpr +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + + +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") + + +def to_fun(t): + if isinstance(t, Tensor): + if is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + else: + return FunctionalTensor.to_functional(t) + else: + return t + + +def sync_functional_tensor(t): + if is_traceable_wrapper_subclass(t): + attrs, _ctx = t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + sync_functional_tensor(getattr(t, attr)) + else: + torch._sync(t) + + +# When subclasses are involved, t here will usually look something like: +# SubclassA(SubclassB(FunctionalTensor(_to_fun_tensor(FakeTensor)))) +def from_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. 
+ # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) # type: ignore[attr-defined] + return t + sync_functional_tensor(t) + return torch._from_functional_tensor(t.elem) + + +def is_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + t_attrs, _ = t.__tensor_flatten__() # type: ignore[attr-defined] + t_inners = [getattr(t, attr) for attr in t_attrs] + any_fun = any(is_fun(x) for x in t_inners) + all_fun = all(is_fun(x) for x in t_inners) + assert any_fun == all_fun + return any_fun + + return isinstance(t, FunctionalTensor) + + +# t here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +def has_data_mutation(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + return any(has_data_mutation(getattr(t, attr)) for attr in attrs) + else: + if isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_has_data_mutation(t.elem) # type: ignore[attr-defined] + return False + + +def are_all_mutations_hidden_from_autograd(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd. 
+ return all( + are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs + ) + elif isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem) + else: + return False + + +def are_all_mutations_under_no_grad_or_inference_mode(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + return all( + are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr)) + for attr in attrs + ) + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode( + t.elem + ) + + +def was_inductor_storage_resized(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + if any(was_inductor_storage_resized(getattr(t, attr)) for attr in attrs): + raise RuntimeError( + f"storage resizing is not supported on tensor subclass: {type(t)}" + ) + elif not isinstance(t, torch.Tensor): + return False + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_was_inductor_storage_resized(t.elem) + + +# f_arg here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +# Assumption: arg promises to be the "original" tensor wrapped by f_arg +# Note: "storage mutations" coming from set_() are a type of metadata mutation. 
So: +# - check_only_storage_mutation=True: only return true if there was a storage mutation +# - check_only_storage_mutation=Flse: return true if there was any metadata mutation (including a storage mutation) +def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool): + if is_traceable_wrapper_subclass(f_arg): + attrs, _ = f_arg.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + f_inner_ts = [getattr(f_arg, attr) for attr in attrs] + inner_ts = [getattr(arg, attr) for attr in attrs] + return any( + has_metadata_mutation( + f_inner_t, + inner_t, + check_only_storage_mutation=check_only_storage_mutation, + ) + for f_inner_t, inner_t in zip(f_inner_ts, inner_ts) + ) + else: + if not isinstance(f_arg, torch.Tensor): + assert not isinstance(arg, torch.Tensor) + return False + assert isinstance(f_arg, FunctionalTensor) + assert isinstance(arg, FakeTensor) + + arg_after = torch._from_functional_tensor(f_arg.elem) + # This is true if the current tensor experienced at least one set_() call + maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem) # type: ignore[attr-defined] + # However, multiple set_() calls can cancel out. So we also check whether the + # storage of the tensor has changed. + # Note: if an input experienced two set_() calls that cancel out, **and** + # it experiences an data mutation, we pessimistically think that the set_() + # call is necessary here. We could in theory fix this, but this will + # hopefully never happen in user code, and is not needed for fsdp. 
+ if is_sparse_any(arg): + # TODO:add sparse tensors support to functionalization + same_storages = False + else: + same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef( + arg_after.untyped_storage() + ) + has_storage_metadata_mutation = maybe_storage_changed and not same_storages + if check_only_storage_mutation: + return has_storage_metadata_mutation + + # storage metadata mutation is a type of metadata mutation, so return true if we saw one + if has_storage_metadata_mutation: + return True + + maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem) # type: ignore[attr-defined] + # This is true if the current tensor experienced at least one metadata mutation. + # So if false, we know there was no metadata mutation + if not maybe_metadata_mutated: + return False + + # However, multi metadata mutations can cancel out. + # So we also check if the concrete sizes/strides on the tensor have changed. + same_sizes = arg.shape == arg_after.shape + same_strides = arg.stride() == arg_after.stride() + same_offsets = arg.storage_offset() == arg_after.storage_offset() + has_metadata_mutation_ = maybe_metadata_mutated and not ( + same_sizes and same_strides and same_offsets + ) + # We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call. + return has_metadata_mutation_ + + +def gen_alias_from_base( + aliased_base_tensor, + target_meta_tensor, + target_requires_grad, + target_view_meta_sequence: ViewMetaSequence | None = None, + *, + replay_views: bool, +): + # Patch the correct requires_grad field of the output tensor, depending on whether: + # (i) the reconstructed output (out) was came from a tensor that requires grad or not; + # and (ii) the concrete returned output does require grad or not. 
+ def patch_requires_grad(out): + if aliased_base_tensor.requires_grad and not target_requires_grad: + out = out.detach() + elif not aliased_base_tensor.requires_grad and target_requires_grad: + out.requires_grad_(True) + return out + + # If provided, use the target functional tensor for replaying the views. + # + # In summary, we use the fact that FunctionalTensorWrapper saves the view + # functions applied to itself (collected during functionalization) so as + # to replay them (view functions) on the aliased_base_tensor. + if ( + replay_views + and target_view_meta_sequence is not None + and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence) + ): + out = _functionalization.apply_view_meta_sequence( + aliased_base_tensor, target_view_meta_sequence.sequence + ) + # If re-applying the ViewMeta sequence succeeded, there should be no more + # problems going forward. We just check we got to the target shape and + # patch requires_grad flag. + assert out.shape == target_meta_tensor.shape, ( + "incorrect out shape after application of ViewMeta sequence: " + f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)" + ) + return patch_requires_grad(out) + + # Try to do view-replay if possible. + # fall back to .as_strided() if we can't. + if target_meta_tensor._base is not None: + # The base that we want to replay our view off of might have a different shape than the view's original base. 
+ b = target_meta_tensor._base + abt = aliased_base_tensor + # Don't unnecessarily call as_strided if nothing changed; as_strided's + # backward is poorly implemented and slow + if abt is not b and ( + abt.size() != b.size() + or abt.stride() != b.stride() + or abt.storage_offset() != b.storage_offset() + ): + reshaped_base_tensor = aliased_base_tensor.as_strided( + b.size(), b.stride(), b.storage_offset() + ) + else: + reshaped_base_tensor = aliased_base_tensor + out = target_meta_tensor._view_func(reshaped_base_tensor) + # This shape mismatch can happen due to a bug in inplace/view handling in autograd. + # Try putting a breakpoint here and running + # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types` + # Also, https://github.com/pytorch/pytorch/issues/49825 + # + # As a stopgap, we'll fall back to as_strided. + if out is not None and out.shape == target_meta_tensor.shape: + return patch_requires_grad(out) + + size = target_meta_tensor.size() + stride = target_meta_tensor.stride() + storage_offset = target_meta_tensor.storage_offset() + if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex(): + aliased_out = torch.view_as_real(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex(): + aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + else: + aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset) + # For outputs aliasing inputs, we need to check if the requires-gradness has changed. + aliased_out = patch_requires_grad(aliased_out) + # For outputs aliasing inputs, we need to check if the dtype has changed. 
+ # as_strided() is the "most generic" view, but it does not cover cross-dtype views + if aliased_out.dtype != target_meta_tensor.dtype: + aliased_out = aliased_out.view(target_meta_tensor.dtype) + return aliased_out + + +def has_same_metadata(t1, t2): + return ( + guard_or_false(sym_eq(t1.size(), t2.size())) + and guard_or_false(t1.layout == t2.layout) + and ( + is_sparse_any(t1) + or ( + guard_or_false(sym_eq(t1.stride(), t2.stride())) + and guard_or_false(t1.storage_offset() == t2.storage_offset()) + ) + ) + and t1.is_conj() == t2.is_conj() + and t1.is_neg() == t2.is_neg() + ) + + +@dataclass(frozen=True) +class MetadataKey: + """ + This should be equal whenever has_same_metadata would return True + """ + + size: tuple[SymIntEqByExpr, ...] + layout: torch.layout + is_sparse: bool + # these are empty when is_sparse + stride: tuple[SymIntEqByExpr, ...] | None + storage_offset: SymIntEqByExpr | None + is_conj: bool + is_neg: bool + + @staticmethod + def make(t): + is_sparse = is_sparse_any(t) + return MetadataKey( + size=tuple(SymIntEqByExpr(s) for s in t.size()), + layout=t.layout, + is_sparse=is_sparse, + stride=None if is_sparse else tuple(SymIntEqByExpr(s) for s in t.stride()), + storage_offset=None if is_sparse else SymIntEqByExpr(t.storage_offset()), + is_conj=t.is_conj(), + is_neg=t.is_neg(), + ) + + +# ViewMeta sequence wrapper for equality comparisons. +# +# Even though we can compare each ViewMeta instance, we compare the resulting +# tensor metadata, instead. That's because the creation of synthetic bases + the +# re-generation of input views might end-up creating a different sequence of +# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same +# metadata. +# +# Therefore, we store what the end result should look like as serializable +# metadata. +# +# When logging, this class should look like: +# +# ViewMetaSequence(view, select_int, slice_Tensor) +# +# i.e. a parenthesized list of view operations within that ViewMeta sequence. 
+class ViewMetaSequence: + def __init__(self, tensor: FunctionalTensor) -> None: + assert torch._is_functional_tensor(tensor.elem) + self.sequence = _functionalization.get_view_meta_sequence(tensor.elem) + self.metadata = MetadataKey.make(tensor) + + def __repr__(self) -> str: + suffix = len("_ViewMeta") + types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence) + return f"ViewMetaSequence({types})" + + def __eq__(self, other: object) -> bool: + # If other is None, then it probably means that we weren't able to recreate + # the ViewMeta sequence. One example is when we update the view metadata by + # calling: create_synthetic_base_metadata. + if other is None: + return True + + # Comparison against any other type is not implemented. + if not isinstance(other, ViewMetaSequence): + return NotImplemented + + return self.metadata == other.metadata + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. +# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed +# +# Normally it would be enough just to check if arg is new_arg, which is normally enough for functionalization +# to confirm that inputs were not mutated when running the user's model with functionalization on. 
+# But when we have subclass inputs, we can't rely on that: +# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs +# a brand new subclass instance: we are calling __tensor_unflatten__, and going +# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor)) +def was_tensor_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. +# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed, +# but shares storage with the old input +def was_tensor_metadata_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg and StorageWeakRef( + arg.untyped_storage() + ) == StorageWeakRef(new_arg.untyped_storage()) + + +# Returns the number of detected copy_ +def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: + allowed_mutation_ops = [ + torch.ops.aten.copy_.default, + torch.ops.aten.set_.source_Tensor, + ] + if 
hasattr(torch.ops.fsdp, "copy_"): + allowed_mutation_ops.append(torch.ops.fsdp.copy_.default) + + placeholders = set() + mutation_count = 0 + # NB: It would also be nice to verify that the mutations all happen at the + # end, but we also do some administrative views after mutations so this + # isn't actually true. (TODO: Could this cause problems for Inductor?) + error = None + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target in allowed_mutation_ops: + # Can only copy_/set_ into an input + # this is mostly a hack to avoid failing XLA tests. + # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 + if "set_buffer_donor_" not in str(n.args[0]): + if n.args[0] not in placeholders: + error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + mutation_count += 1 + else: + if n.target._schema.is_mutable: + error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return error, mutation_count + + +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + error, mutation_count = _is_functional_graph(fx_g) + assert error is None, error + return mutation_count + + +def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None: + placeholders = set() + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target is torch.ops.aten.copy_.default: + # Can only copy_ into an input, and can only do so once + if "set_buffer_donor_" not in str(n.args[0]): + assert n.args[0] in placeholders, ( + f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + ) + placeholders.remove(n.args[0]) + copy_from_node = n.args[1] + # Pre-condition: every node has a "stack_trace" field in its meta, + # but copy_() nodes do not (since we manually added them during functionalization). 
+ # Instead, we manually propagate here. + if "stack_trace" in copy_from_node.meta: + n.meta["stack_trace"] = copy_from_node.meta["stack_trace"] + + +def _check_if_mutation_can_be_in_graph( + keep_input_mutations: bool, + mutates_data, + mutates_metadata, + mutations_hidden_from_autograd, + mutations_under_no_grad_or_inference_mode, + mutates_storage_metadata, + mutation_inductor_storage_resize, + requires_grad, +): + if keep_input_mutations: + in_graph = ( + mutates_data or mutates_storage_metadata or mutation_inductor_storage_resize + ) and ( + (not mutates_metadata and not requires_grad) + or mutations_hidden_from_autograd + or mutations_under_no_grad_or_inference_mode + ) + else: + in_graph = False + # See Note [set_() Input Mutations in AOTAutograd] + # If there was a `set_()`, we require that all mutations were under no_grad, + # so we can (safely) emit the set_() in the graph at runtime + # resize_() gets the same treatment + if mutation_inductor_storage_resize or mutates_storage_metadata: + op_name = "resize_" if mutation_inductor_storage_resize else "set_" + assert in_graph, f"""\ +Encountered a {op_name} on a graph input, but the input has other mutations that we cannot +keep in the graph. This is not supported today. 
Current state: + keep_input_mutations={keep_input_mutations} + mutates_data={mutates_data} + mutates_metadata={mutates_metadata} + mutations_hidden_from_autograd={mutations_hidden_from_autograd} + mutations_under_no_grad_or_inference_mode={mutations_under_no_grad_or_inference_mode} + mutation_inductor_storage_resize={mutation_inductor_storage_resize} + requires_grad={requires_grad}""" + return in_graph diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/fx_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..491cf3e1fe8cfad65cea4394b0eb2bcbe9832910 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/fx_utils.py @@ -0,0 +1,317 @@ +""" +This module contains utility functions for working with joint FX graphs with descriptors +that are produced by AOTAutograd. They will NOT work on generic FX graphs. See also +:func:`torch._functorch.aot_autograd.aot_export_joint_with_descriptors`. We also +recommend reading :mod:`torch._functorch._aot_autograd.descriptors`. +""" + +from typing import NoReturn, Optional, Union + +import torch.fx as fx + +from .descriptors import ( + AOTInput, + AOTOutput, + BufferAOTInput, + DifferentiableAOTInput, + DifferentiableAOTOutput, + GradAOTOutput, + ParamAOTInput, + PlainAOTInput, + PlainAOTOutput, + SubclassGetAttrAOTInput, + SubclassGetAttrAOTOutput, + TangentAOTInput, +) + + +def _raise_autograd_subclass_not_implemented( + n: fx.Node, desc: Union[AOTInput, AOTOutput] +) -> NoReturn: + raise RuntimeError( + "Subclasses are currently not supported by this function, but a desugared subclass input " + f"was found at {n} ({desc}). 
The problem is " + "that there may not necessarily be a 1-1 correspondence between primals/tangents/outputs/grads " + "when subclasses are involved: for example, the primal might be a plain tensor " + "but the tangent a tensor subclass that desugared into multiple plain tensors. " + "It is not clear what exactly you would like this function to do in this case " + "(Collect all nodes for the subclass together? Match up the inner nodes if " + "subclasses match exactly?) If you have a concrete use case, please file an " + "issue so we can understand it and design an API that works for your case." + ) + + +def get_all_input_and_grad_nodes( + g: fx.Graph, +) -> dict[DifferentiableAOTInput, tuple[fx.Node, Optional[fx.Node]]]: + """ + Given a joint graph with descriptors (meta['desc'] on placeholders and + output), returns the node for every input and its corresponding grad + output node if it exists. These tuples are in a dict that is indexed by + the AOTInput descriptor that describes the input. + + NB: *all* forward tensor inputs are returned, including non-differentiable + inputs (which simply have a None grad), so it is safe to use this function + to perform operations on all inputs. (Non-tensor inputs like symbolic + integers, tokens or RNG state are NOT traversed by this function.) + + Args: + g: The FX joint graph with descriptors + + Returns: + A dictionary mapping each DifferentiableAOTInput descriptor to a tuple + containing: + - The input node itself + - The grad (output) node if it exists, None otherwise + + Raises: + RuntimeError: If the joint graph has subclass tensor inputs/outputs; this + is not supported by API as there is not necessarily a 1-1 correspondence + between inputs and grads when subclasses are involved. 
+ """ + input_index: dict[DifferentiableAOTInput, tuple[fx.Node, Optional[fx.Node]]] = {} + for n in g.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + # Skip inputs that cannot possibly be differentiable + if not isinstance(desc, DifferentiableAOTInput): + continue + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_autograd_subclass_not_implemented(n, desc) + # pyrefly: ignore [unsupported-operation] + input_index[desc] = (n, None) + elif n.op == "output": + assert "desc" in n.meta, (n, n.meta) + desc = n.meta["desc"] + for sub_n, sub_desc in zip(n.args[0], desc): + if isinstance(sub_desc, SubclassGetAttrAOTOutput): + _raise_autograd_subclass_not_implemented(sub_n, sub_desc) + if isinstance(sub_desc, GradAOTOutput): + inp, grad = input_index[sub_desc.grad_of] + assert grad is None, (sub_n, sub_desc, input_index) + input_index[sub_desc.grad_of] = (inp, sub_n) + return input_index + + +def get_all_output_and_tangent_nodes( + g: fx.Graph, +) -> dict[DifferentiableAOTOutput, tuple[fx.Node, Optional[fx.Node]]]: + """Get all output nodes and their corresponding tangent nodes from a joint graph. + + Similar to get_all_input_and_grad_nodes, but returns output nodes paired with + their tangent nodes (if they exist). This function traverses the graph to find + all differentiable outputs and matches them with their corresponding tangent + inputs used in forward-mode autodiff. + + NB: *all* forward tensor outputs are returned, including non-differentiable outputs, + so you can use this function to perform operations on all outputs. 
+ + Args: + g: The FX joint graph with descriptors + + Returns: + A dictionary mapping each DifferentiableAOTOutput descriptor to a tuple + containing: + - The output node itself + - The tangent (input) node if it exists, None otherwise + + Raises: + RuntimeError: If the joint graph has subclass tensor inputs/outputs; this + is not supported by API as there is not necessarily a 1-1 correspondence + between outputs and tangents when subclasses are involved. + """ + output_index: dict[DifferentiableAOTOutput, tuple[fx.Node, Optional[fx.Node]]] = {} + for n in g.nodes: + if n.op == "output": + desc = n.meta["desc"] + for sub_n, sub_d in zip(n.args[0], desc): + # Skip outputs that cannot possibly be differentiable + if not isinstance(sub_d, DifferentiableAOTOutput): + continue + if isinstance(sub_d, SubclassGetAttrAOTOutput): + _raise_autograd_subclass_not_implemented(sub_n, sub_d) + # pyrefly: ignore [unsupported-operation] + output_index[sub_d] = (sub_n, None) + for n in g.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_autograd_subclass_not_implemented(n, desc) + if isinstance(desc, TangentAOTInput): + out, tangent = output_index[desc.output] + assert tangent is None, (n, desc, output_index) + output_index[desc.output] = (out, n) + return output_index + + +def get_param_and_grad_nodes( + graph: fx.Graph, +) -> dict[ParamAOTInput, tuple[fx.Node, Optional[fx.Node]]]: + """Get parameter nodes and their corresponding gradient nodes from a joint graph. 
+ + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping each ParamAOTInput descriptor to a tuple containing: + - The parameter input node + - The gradient (output) node if it exists, None otherwise + """ + return { + desc: (n, g) + for desc, (n, g) in get_all_input_and_grad_nodes(graph).items() + if isinstance(desc, ParamAOTInput) + } + + +def get_plain_input_and_grad_nodes( + graph: fx.Graph, +) -> dict[PlainAOTInput, tuple[fx.Node, Optional[fx.Node]]]: + """Get plain input nodes and their corresponding gradient nodes from a joint graph. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping each PlainAOTInput descriptor to a tuple containing: + - The plain input node + - The gradient (output) node if it exists, None otherwise + """ + return { + desc: (n, g) + for desc, (n, g) in get_all_input_and_grad_nodes(graph).items() + if isinstance(desc, PlainAOTInput) + } + + +def get_plain_output_and_tangent_nodes( + graph: fx.Graph, +) -> dict[PlainAOTOutput, tuple[fx.Node, Optional[fx.Node]]]: + """Get plain output nodes and their corresponding tangent nodes from a joint graph. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping each PlainAOTOutput descriptor to a tuple containing: + - The plain output node + - The tangent (input) node if it exists, None otherwise + """ + return { + desc: (n, g) + for desc, (n, g) in get_all_output_and_tangent_nodes(graph).items() + if isinstance(desc, PlainAOTOutput) + } + + +def _raise_fqn_subclass_not_implemented( + n: fx.Node, desc: Union[AOTInput, AOTOutput] +) -> NoReturn: + raise RuntimeError( + "Subclasses are currently not supported by this function, but a desugared subclass input " + f"was found at {n} ({desc}). 
The problem is " + "that there may not necessarily be a 1-1 correspondence between a FQN and a plain tensor " + "when subclasses are involved: for example, a parameter that is a subclass " + "would desugar into multiple plain tensors, which we can't uniquely assign the " + "FQN to. It's not clear what you want the API to do in this case: do you want to " + "instead return a struct of nodes showing how to assemble the subclass? But you " + "don't (directly) have the metadata for the subclass? If you have a concrete use " + "case, please file an issue so we can understand it and design an API that works for your case." + ) + + +def get_named_param_nodes(graph: fx.Graph) -> dict[str, fx.Node]: + """Get parameter nodes mapped by their fully qualified names. + + This function traverses the graph to find all parameter input nodes and + returns them in a dictionary where keys are the parameter names (FQNs) + and values are the corresponding FX nodes. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping parameter names (str) to their corresponding FX nodes. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + with subclasses a FQN does not necessarily map to a single plain tensor. + """ + r = {} + for n in graph.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_fqn_subclass_not_implemented(n, desc) + elif isinstance(desc, ParamAOTInput): + r[desc.target] = n + return r + + +def get_named_buffer_nodes(graph: fx.Graph) -> dict[str, fx.Node]: + """Get buffer nodes mapped by their fully qualified names. + + This function traverses the graph to find all buffer input nodes and + returns them in a dictionary where keys are the buffer names (FQNs) + and values are the corresponding FX nodes. + + Args: + graph: The FX joint graph with descriptors + + Returns: + A dictionary mapping buffer names (str) to their corresponding FX nodes. 
+ + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + with subclasses a FQN does not necessarily map to a single plain tensor. + """ + r = {} + for n in graph.nodes: + if n.op == "placeholder": + desc = n.meta["desc"] + if isinstance(desc, SubclassGetAttrAOTInput): + _raise_fqn_subclass_not_implemented(n, desc) + elif isinstance(desc, BufferAOTInput): + r[desc.target] = n + return r + + +def get_param_nodes(graph: fx.Graph) -> list[fx.Node]: + """Get all parameter nodes from a graph as a list. + + You can rely on this providing the correct order of parameters you need + to feed into the joint graph (at the very beginning of the argument list, + before buffers). + + Args: + graph: The FX joint graph with descriptors + + Returns: + A list of FX nodes representing all parameters in the graph. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + it is not clear if you wanted each individual constituent piece of the + subclasses, or have them grouped up in some way. + """ + return list(get_named_param_nodes(graph).values()) + + +def get_buffer_nodes(graph: fx.Graph) -> list[fx.Node]: + """Get all buffer nodes from a graph as a list. + + You can rely on this providing the correct order of buffers you need + to feed into the joint graph (after parameters). + + Args: + graph: The FX joint graph with descriptors + + Returns: + A list of FX nodes representing all buffers in the graph. + + Raises: + RuntimeError: If subclass tensors are encountered (not yet supported), as + it is not clear if you wanted each individual constituent piece of the + subclasses, or have them grouped up in some way. 
+ """ + return list(get_named_buffer_nodes(graph).values()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef84cb488604c1c55b36890f270f3255a8ee138 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -0,0 +1,1395 @@ +# mypy: allow-untyped-defs +""" +This module is responsible for transforming functions to be traced into a form +that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis) +to handle. + +It does so by: +1. functionalization (including RNG functionalization) +2. creating a joint graph when required +3. transforming mutations into extra outputs +4. dispatching subclasses +""" + +import warnings +from collections.abc import Callable +from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import Any, Optional, TypeVar, Union +from unittest.mock import patch + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker +from torch._guards import detect_fake_mode +from torch._prims_common import CUDARngStateHelper +from torch.fx.experimental.proxy_tensor import ( + _proxy_tensor_disable_update_tensor_tracker, + get_proxy_mode, + maybe_disable_thunkify, + maybe_enable_thunkify, +) +from torch.fx.experimental.symbolic_shapes import ( + guard_or_true, + PropagateUnbackedSymInts, + sym_eq, +) +from torch.nn.utils import stateless +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._pytree import TreeSpec + +from .. 
import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .descriptors import ( + AOTInput, + AOTOutput, + BackwardTokenAOTOutput, + ForwardTokenAOTInput, + ForwardTokenAOTOutput, + GradAOTOutput, + InputMutationAOTOutput, + IntermediateBaseAOTOutput, + PhiloxBackwardBaseOffsetAOTInput, + PhiloxBackwardSeedAOTInput, + PhiloxForwardBaseOffsetAOTInput, + PhiloxForwardSeedAOTInput, + PhiloxUpdatedBackwardOffsetAOTOutput, + PhiloxUpdatedForwardOffsetAOTOutput, +) +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, + was_inductor_storage_resized, +) +from .logging_utils import setup_stacktrace_preservation_hooks +from .schemas import ( + AOTConfig, + FxValue, + JointTraceFn, + MutationType, + OutputType, + PreppedForAutogradTraceFn, + SubclassMeta, + SubclassTracingInfo, + TraceFn, + ViewAndMutationMeta, +) +from .subclass_utils import ( + create_subclass_meta, + remap_unwrapped_subclass_arg_indices, + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from .utils import ( + call_and_expect_output_descs, + maybe_to_fresh_input, + simple_wraps, + without_output_descs, +) + + +# This function returns a new function that returns mutated inputs as outputs. +# if keep_data_input_mutations is set, then we assume that data-only mutations +# will be left in the graph, and we only return metadata-mutated inputs as outputs. 
+def fn_input_mutations_to_outputs( + fn: Callable, + args_descs: list[AOTInput], + meta: ViewAndMutationMeta, + keep_data_input_mutations: bool, +) -> Any: + @simple_wraps(fn) + def inner_fn(*args): + outs, outs_descs = call_and_expect_output_descs(fn, args) + assert len(meta.output_info) == len(outs) + # The compiled fw will return mutated input tensors, *including* metadata-only mutation. + # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. + # (because data-only input mutations are handled directly in the compiled graph) + mutated_input_pairs = [ + (x, InputMutationAOTOutput(src)) + for (i, (x, src)) in enumerate(zip(args, args_descs)) + if i in meta.mutated_inp_runtime_indices + ] + if mutated_input_pairs: + mutated_inputs_to_return, mutated_inputs_to_return_descs = zip( + *mutated_input_pairs + ) + else: + mutated_inputs_to_return, mutated_inputs_to_return_descs = (), () + return ( + (*mutated_inputs_to_return, *outs), + (*mutated_inputs_to_return_descs, *outs_descs), + ) + + return inner_fn + + +@contextmanager +def disable_autocast(): + with ExitStack() as stack: + autocast_enabled_devices = torch._C._autocast_supported_devices() + for device_type in autocast_enabled_devices: + if hasattr(torch, device_type): + stack.enter_context(torch.amp.autocast(device_type, enabled=False)) + yield + + +# This function takes in a fn with external aliasing and mutation, +# and returns a new fn with no external aliasing and mutation, +# as needed for autograd. 
+# The main transformations are: +# - Return mutated inputs as extra outputs +# - Clone mutated inputs that require gradients, +# because autograd will require us to pass the pre-mutated inputs into autograd.grad +# - Return intermediate bases of outputs as additional outputs, +# needed to appease autograd.Function +# The new function returns: +# (1) The updated outputs +# (2) A boolean mask of len(new_fn_outputs), +# that can be used to tell autograd.grad which outputs should get tangents +# if we trace the backward. +def fn_prepped_for_autograd( + fn: TraceFn, + args_descs: list[AOTInput], + meta: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> PreppedForAutogradTraceFn: + @simple_wraps(fn) + def inner_fn(*args): + args_maybe_cloned = [ + maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args) + ] + + outs, outs_descs = call_and_expect_output_descs(fn, args_maybe_cloned) + assert isinstance(outs, (tuple, list)) + outs = list(outs) + assert len(meta.output_info) == len(outs) + + mutated_input_pairs = [ + (x, InputMutationAOTOutput(src)) + for (i, (x, src)) in enumerate(zip(args_maybe_cloned, args_descs)) + if i in meta.mutated_inp_runtime_indices + ] + if mutated_input_pairs: + mutated_inputs_to_return, mutated_inputs_to_return_descs = zip( + *mutated_input_pairs + ) + else: + mutated_inputs_to_return, mutated_inputs_to_return_descs = (), () + + intermediate_bases = [] + intermediate_bases_descs = [] + for o, info, o_desc in zip(outs, meta.output_info, outs_descs): + if info.output_type == OutputType.alias_of_intermediate_save_as_output: + assert isinstance(o, torch.Tensor), ( + f"Expected tensor for intermediate base, got {type(o)}" + ) + intermediate_bases.append(o._base) + intermediate_bases_descs.append(IntermediateBaseAOTOutput(o_desc)) + + assert meta.num_intermediate_bases == len(intermediate_bases) + + # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) + fw_outs_to_return = *mutated_inputs_to_return, *outs, 
*intermediate_bases + fw_outs_to_return_descs = ( + *mutated_inputs_to_return_descs, + *outs_descs, + *intermediate_bases_descs, + ) + + # Also return a boolean mask specifying which outputs to this function will be used as tangents + mutated_inputs_grad_mask = [ + meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data + and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad + for (i, x) in enumerate(mutated_inputs_to_return) + ] + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + output_grad_mask = [ + meta.output_info[i].output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.output_info[i].raw_type, Tensor) + and meta.output_info[i].requires_grad + for (i, x) in enumerate(outs) + ] + + intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))] + + out_grad_mask = ( + mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask + ) + assert len(out_grad_mask) == len(fw_outs_to_return) + + # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) + # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + # This is annoying: our joint function needs to be aware of functionalization + # (syncing mutated inputs before calling autograd.grad()) + # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner. 
+ if not aot_config.disable_functionalization: + for arg in args_maybe_cloned: + if not isinstance(arg, Tensor): + continue + sync_functional_tensor(arg) + + return (fw_outs_to_return, out_grad_mask), ( + fw_outs_to_return_descs, + out_grad_mask, + ) + + return inner_fn + + +@dataclass +class JointFnHandle: + post_forward: Optional[Callable] = None + + +# Given a fn, computes the joint. +# NOTE: fn is expects the following behavior: +# (1) fn() needs to return a tuple of (outs, mask), +# where `mask` tells us which outputs are meant to have tangents. +# we don't know this info automatically, because we don't actually want to blindly +# compute tangents for every output that requires grad. +# Specifically, outputs that alias inputs won't participate in the backward and get tangents. +# (2) fn() cannot mutate any inputs that require gradient. +# otherwise, when we compute autograd.grad(), we will not take those input mutations into account +# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) +def create_joint( + fn: Any, # PreppedForAutogradTraceFn + primals_descs: Optional[list[AOTInput]] = None, + *, + aot_config: AOTConfig, +) -> Any: # JointTraceFn + joint_fn_handle = JointFnHandle() + + # post_forward + # NB: this type is inaccurate when primals_descs is None + @simple_wraps(fn) + def inner_fn( + primals: list[FxValue], tangents: list[FxValue] + ) -> tuple[ + tuple[list[FxValue], list[Optional[Tensor]]], + tuple[list[AOTOutput], list[Optional[AOTOutput]]], + ]: + outs_descs = None + if primals_descs is None: + outs, tangent_mask = fn(*primals) + assert not pytree.tree_any(lambda x: isinstance(x, AOTOutput), tangent_mask) + else: + (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( + fn, primals + ) + mode = get_proxy_mode() + assert mode is not None, "Expected non-None proxy mode" + for node in mode.tracer.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" + + # TODO: I think this hook 
can also be eliminated now + if joint_fn_handle and joint_fn_handle.post_forward: + joint_fn_handle.post_forward(primals) + + assert len(tangent_mask) == len(outs) + outs_to_grad = [ + o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent + ] + assert len(outs_to_grad) == len(tangents) + + # Get the inputs that need gradients + grad_primals: list[torch.Tensor] = [] + inputs_needs_grads = [] + # Note that we're not using primals here, + # being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + if isinstance(p, Tensor) and p.requires_grad: + inputs_needs_grads.append(True) + assert isinstance(p, torch.Tensor) # Help mypy understand the type + grad_primals.append(p) + else: + inputs_needs_grads.append(False) + + # Get the outputs that need gradients + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + # The guard_or_true also sketchy; if unbacked + # symints are involved, we're just going to assume that the + # decomps setup the base shape correctly + + # Return out if the result of out.shape==tangent.shape is unknown or known to be true. + # otherwise if its a known false return out.view(tangent.shape). 
+ # tangent should also be a tensor since it corresponds to a tensor output + assert isinstance(tangent, torch.Tensor), ( + f"Expected tensor tangent, got {type(tangent)}" + ) + needed_outs.append( + out + if guard_or_true(sym_eq(out.shape, tangent.shape)) + else out.view(tangent.shape) + ) + needed_tangents.append(tangent) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + if config.functionalize_rng_ops: + PhiloxStateTracker.mark_beginning_of_backward() + backward_out: tuple[Tensor, ...] = () + # Call the backwards pass + if grad_primals: + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + if functional_tensor_mode is not None: + # Side-Effect Tokens: + # We want to have independent chains of tokens for forward and backward. + # functional_tensor_mode._tokens is used by both. + # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output, + # to return them as joint graph outputs. + # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward. + # Joint graph tracing allows tokens discovery, + # So all the tokens in backward will be created and added as a graph inputs during tracing. + functional_tensor_mode._tokens_forward_output = ( + functional_tensor_mode._tokens + ) + functional_tensor_mode._tokens = {} + + with ( + set_partitioner_tag_is_backward(), + fx_traceback.preserve_node_meta(), + ExitStack() as stack, + ): + backward_pass_autocast = torch._functorch.config.backward_pass_autocast + if backward_pass_autocast == "same_as_forward": + # Use the ambient autocast mode(s) + pass + elif backward_pass_autocast == "off": + stack.enter_context(disable_autocast()) + else: + # Disable autocast, then enable anything in `backward_pass_autocast`. 
+ stack.enter_context(disable_autocast()) + assert isinstance(backward_pass_autocast, list) + for kwargs in backward_pass_autocast: + assert isinstance(kwargs, dict) + stack.enter_context(torch.amp.autocast(**kwargs)) + + # for full graph export, we always export a joint graph where we assume no tangents are needed. + if aot_config.no_tangents: + assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1 + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + allow_unused=True, + ) + else: + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + final_outs = ( + outs, + [next(backward_out_iter) if i else None for i in inputs_needs_grads], + ) + if primals_descs is None: + return final_outs # type: ignore[return-value] + assert outs_descs is not None + return final_outs, ( + outs_descs, + [ + # TODO: ideally we do know this is DifferentiableAOTInput + # but this is quite an involved refactor + GradAOTOutput(desc) if i else None # type: ignore[arg-type] + for i, desc in zip(inputs_needs_grads, primals_descs) + ], + ) + + @simple_wraps(inner_fn) + def inner_fn_with_anomaly( + primals: list[FxValue], tangents: list[FxValue] + ) -> tuple[ + tuple[list[FxValue], list[Optional[Tensor]]], + tuple[list[AOTOutput], list[Optional[AOTOutput]]], + ]: + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.") + with torch.autograd.detect_anomaly(check_nan=False): + return inner_fn(primals, tangents) + + def joint_helper(primals, tangents): + return inner_fn_with_anomaly(primals, tangents) + + joint_helper.handle = joint_fn_handle # type: ignore[attr-defined] + + return joint_helper + + +def create_functionalized_rng_ops_wrapper( + func, args, args_descs, trace_joint=True +) -> Any: + # Functionalization of rng ops changes the calling convention of the joint 
graph. + # It goes from (primals, tangents) to (seed, offset, primals, tangents) + # At runtime, we pass on the current seed and offset. This is hidden from + # the user. + fake_mode_det = detect_fake_mode() + fake_mode: AbstractContextManager[Any] = nullcontext() + if fake_mode_det is not None: + fake_mode = fake_mode_det + + def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"): + out = PhiloxStateTracker.get_state_as_tensor() + return out + + def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"): + PhiloxStateTracker.set_state_from_tensor(x) + + def append_rng_offsets(outs, outs_descs): + if trace_joint: + # outs signature before: Tuple(fwd_outputs), Tuple(bwd_outputs) + # outs signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_offset, new_bwd_rng_offset) + return ( + ( + (*outs[0], PhiloxStateTracker.get_updated_fwd_offset()), + (*outs[1], PhiloxStateTracker.get_updated_bwd_offset()), + ), + ( + (*outs_descs[0], PhiloxUpdatedForwardOffsetAOTOutput()), + (*outs_descs[1], PhiloxUpdatedBackwardOffsetAOTOutput()), + ), + ) + else: + # outs signature before: Tuple(fwd_outputs) + # outs signature after: Tuple(fwd_outputs, new_fwd_rng_offset) + return ( + (*outs, PhiloxStateTracker.get_updated_fwd_offset()), + (*outs_descs, PhiloxUpdatedForwardOffsetAOTOutput()), + ) + + def traced_joint( + primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset + ): + with ( + patch("torch.cuda.get_rng_state", override_get_rng_state), + patch("torch.cuda.set_rng_state", override_set_rng_state), + ): + return append_rng_offsets(*func(primals, tangents)) + + def traced_forward(*primals_fwd_seed_fwd_base_offset): + # The signature is (*primals, seed, offset) + with ( + patch("torch.cuda.get_rng_state", override_get_rng_state), + patch("torch.cuda.set_rng_state", override_set_rng_state), + ): + return append_rng_offsets(*func(*primals_fwd_seed_fwd_base_offset[:-2])) + + if trace_joint: + # Get the current seed 
and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward") + return ( + traced_joint, + ( + *args, + fwd_seed, + fwd_base_offset, + bwd_seed, + bwd_base_offset, + ), + ( + *args_descs, + PhiloxForwardSeedAOTInput(), + PhiloxForwardBaseOffsetAOTInput(), + PhiloxBackwardSeedAOTInput(), + PhiloxBackwardBaseOffsetAOTInput(), + ), + ) + else: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + return ( + traced_forward, + (*args, fwd_seed, fwd_base_offset), + ( + *args_descs, + PhiloxForwardSeedAOTInput(), + PhiloxForwardBaseOffsetAOTInput(), + ), + ) + + +@contextmanager +def set_partitioner_tag(tag: str): + meta_key = "partitioner_tag" + assert fx_traceback.has_preserved_node_meta() + + original_val = fx_traceback.current_meta.get(meta_key, None) + fx_traceback.current_meta[meta_key] = tag + try: + yield + finally: + fx_traceback.current_meta[meta_key] = original_val + + +def set_partitioner_tag_is_backward(): + return set_partitioner_tag("is_backward") + + +def set_partitioner_tag_must_be_in_backward(): + return set_partitioner_tag("must_be_in_backward") + + +def set_partitioner_tag_must_be_in_forward(): + return set_partitioner_tag("must_be_in_forward") + + +@dataclass +class MutationCounters: + mc_data: int + mc_storage: int + mc_inductor_storage_resized: int + + +T = TypeVar("T") + + +def sc_visit( + t, fn: Callable[[Tensor], T], reduce_fn: Callable[[T, T], T], accum_init: T +) -> T: + if not is_traceable_wrapper_subclass(t): + return fn(t) + + accum = accum_init + + def visit(e): + if not 
is_traceable_wrapper_subclass(e): + nonlocal accum + accum = reduce_fn(accum, fn(e)) + return + + for a in e.__tensor_flatten__()[0]: + visit(getattr(e, a)) + + visit(t) + return accum + + +def _get_mutation_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_mutation_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_storage_changed_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_storage_changed_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_inductor_storage_resized_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_inductor_storage_resized_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_mutation_counters(t) -> MutationCounters: + return MutationCounters( + _get_mutation_counter(t), + _get_storage_changed_counter(t), + _get_inductor_storage_resized_counter(t), + ) + + +def apply_in_graph_mutations( + input_info, + inpt_old, + inpt_new, + f_inpt, + input_idx, + mcs: Optional[MutationCounters] = None, + applied_mcs: Optional[MutationCounters] = None, +): + assert input_info.mutation_type == MutationType.MUTATED_IN_GRAPH + # See Note [set_() Input Mutations in AOTAutograd] + # all mutations on the input must be under no_grad, so it is safe to put in the graph + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. 
+ # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + if input_info.mutates_storage_metadata: + if mcs is None or mcs.mc_storage > applied_mcs.mc_storage: # type: ignore[union-attr] + with torch.no_grad(): + inpt_old.set_(inpt_new) + + # Note [Ordering of resize_() and set_()] + # Importantly: the common usage in FSDP is that we have a dummy parameter + # that sees a set_() and **Then** a resize_(). + # We must put those mutations into the graph in the same order, + # Since running them in the opposite order will have different behavior. + # We fully ban resize_() followed by set_() for now, although in principal + # we could support this + if input_info.mutation_inductor_storage_resize: + if ( + mcs is None + or mcs.mc_inductor_storage_resized > applied_mcs.mc_inductor_storage_resized # type: ignore[union-attr] + ): + # resizing is not supported on subclasses (we error earlier if this happens) + from torch._subclasses.functional_tensor import FunctionalTensor + + assert isinstance(f_inpt, FunctionalTensor) + old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + f_inpt.elem, before=True + ) + new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + f_inpt.elem, before=False + ) + if old_storage_size != new_storage_size: + assert old_storage_size == 0 or new_storage_size == 0, f"""\ + Encosize during tracing on input {input_idx}. 
Old nbytes={old_storage_size}, new nbytes={new_storage_size} + We oresizing on graph inputs as long as the input either starts or ends with a storage size of 0 + (thee for FSDP)""" + torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size) + if new_storage_size == 0: + # Even if we marked the input as having a data mutation (thus needing a copy_()), + # We should **ignore** it if our input has no storage + # (this can happen if, e.g. we temporarily resize our input, copy data into it, + # and resize it back down to zero) + return + + # Optimization: if the copy_() is a no-op then don't include it in the graph. + # In theory inductor could optimize this away, however in fsdp, we end up with + # param.copy_(param), where param is a zero-storage-size tensor, + # and running this op in eager mode (using the aot_eager backend) will result in a segfault. + # So we may as well optimize it away here. + if inpt_old is inpt_new: + # (This check needs to be done after putting resize_() in the graph, + # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + return + # We found an input that had a (data-only) mutation. + # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. 
+ + if not input_info.mutates_data: + return + + if mcs is not None and mcs.mc_data <= applied_mcs.mc_data: # type: ignore[union-attr] + return + + if input_info.mutations_hidden_from_autograd: + # Hidden from autograd = run under no_grad, **and** don't bump VC + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: + inpt_old.copy_(inpt_new) + elif input_info.mutations_under_no_grad_or_inference_mode: + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + + with torch.no_grad(): + inpt_old.copy_(inpt_new) + else: + inpt_old.copy_(inpt_new) + + +# This creates the final function that we want to trace using make_fx(), +# in both aot_dispatch_autograd and aot_dispatch_base. +# Preconditions: +# - fn corresponds to the user's fw function +# - fn arguments have been flattened, duplicate arguments have been handled +# - In the returned function, the "primals" arguments *includes* synthetic bases. +# This function does the work of functionalizing the input function, +# and performing copy_() calls at the end of the function if `keep_input_mutations` is set. +# The function returned has signature that is either: +# (1) "traced_fn(primals: List[Any])" if trace_joint is False +# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True +# Returns a new (functionalized) function, and updated arguments to call it with. 
+def create_functionalized_fn( + fn, + args, + args_descs, + *, + meta: ViewAndMutationMeta, + aot_config: AOTConfig, + trace_joint: bool, + joint_fn_handle: Optional[JointFnHandle] = None, +) -> Any: + primals_after_forward = None + f_args_after_forward = None + f_args_mutation_counters_after_forward: Optional[list[MutationCounters]] = None + inputs_mutated_in_graph = [ + info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info + ] + has_input_mutated_in_graph = any(inputs_mutated_in_graph) + + @simple_wraps(fn) + def _functionalized_f_helper( + *args: list[FxValue], + ) -> tuple[tuple[list[FxValue], list[Tensor]], list[Optional[AOTOutput]]]: + with maybe_enable_thunkify(): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # The functionalization code here can potentially trigger traces + # into the graph, but we'd prefer to NOT do this, because if we + # trace them now, we will end up with FX nodes that don't have + # module stack annotations, which makes unflattener unhappy. 
+ # Wrap inputs into functional wrappers + f_args = pytree.tree_map(to_fun, args) + + if trace_joint and has_input_mutated_in_graph and joint_fn_handle: + # TODO(ivankobzarev): Support fw and bw mutations for subclasses + def _post_forward(primals): + nonlocal primals_after_forward + primals_after_forward = pytree.tree_map(from_fun, primals) + nonlocal f_args_after_forward + f_args_after_forward = f_args[0] + nonlocal f_args_mutation_counters_after_forward + f_args_mutation_counters_after_forward = [ + MutationCounters(-1, -1, -1) + if not inputs_mutated_in_graph[i] + else _get_mutation_counters(f_arg) + for i, f_arg in enumerate(f_args_after_forward) + ] + + joint_fn_handle.post_forward = _post_forward + + # Run the joint + f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args) + + if trace_joint: + # We support a limited amount of mutation of graph inputs during the backward pass. + # (This is used e.g. by Float8, which needs to update buffers during the backward pass) + # Here, we perform extra checks for primals that were mutated in the **backward** + # We're doing the checks here instead of doing them with the rest of the input mutation handling because: + # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened + # during the forward, because the handling is different: some input mutations from the forward + # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same + # types of mutations in the backward we would need a bw-only runtime epilogue. + # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in + # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would + # require an extra round of tracing though, so it's more efficient to do in-line here. 
+ assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw) + primals_before = args[0] + primals_after = pytree.tree_map(from_fun, f_args[0]) + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip(f_args[0], primals_before, primals_after, meta.input_info) + ): + # Store information about mutations in joint(for backward analysis) + joint_mutates_data = has_data_mutation(f_inpt) + + joint_mutates_metadata = has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) + + # Ban metadata mutations on fw inputs during the bw + if not inpt_info.mutates_metadata: + assert not joint_mutates_metadata, ( + "Found a graph input that had its metadata mutated in the backward. This is not supported" + ) + + # Ban storage resizing on fw inputs during the bw + if not inpt_info.mutation_inductor_storage_resize: + assert not was_inductor_storage_resized(f_inpt), ( + "Found a graph input that had storage resizing in the backward. This is not supported" + ) + + # Allow data mutations on fw inputs during the bw, but only if they do not require grad + # So we can guarantee that we can keep the mutations in the graph + if ( + joint_mutates_data + and not inpt_info.mutates_data + and not inpt_info.mutates_storage_metadata + ): + # Not banning here mutations on inpt_info.requires_grad - + # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) + # Add node meta for copy_ for partitioner that this node should be in backward graph. 
+ with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_backward(), + ): + # before and after should be tensors if we're calling copy_ on them + assert isinstance(before, torch.Tensor) and isinstance( + after, torch.Tensor + ) + before.copy_(after) + meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( + idx + ) + # Now that we covered mutations to *forward* inputs during the backward, + # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). + # Today, we will just error in all cases of this happening unless someone needs us to support it. + tangents_before = args[1] + tangents_after = pytree.tree_map(from_fun, f_args[1]) + for f_inpt, before, after in zip( + f_args[1], tangents_before, tangents_after + ): + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ), ( + "Found an input to the backward that had metadata mutated during the backward pass. This is not supported" + ) + if has_data_mutation(f_inpt): + can_be_in_graph = _check_if_mutation_can_be_in_graph( + keep_input_mutations=True, + mutates_data=True, + mutates_metadata=False, + mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd( + f_inpt + ), + mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode( + f_inpt + ), + mutates_storage_metadata=False, + mutation_inductor_storage_resize=was_inductor_storage_resized( + f_inpt + ), + requires_grad=f_inpt.requires_grad, + ) + assert can_be_in_graph, ( + "a backward input that had data mutated in an autograd-aware way. 
This is not supported" + ) + # Perform the input mutation + with torch.fx.traceback.preserve_node_meta(): + # before and after should be tensors if we're calling copy_ on them + assert isinstance(before, torch.Tensor) and isinstance( + after, torch.Tensor + ) + before.copy_(after) + + if aot_config.keep_inference_input_mutations: + # Note: This is a bit annoying. There's a layering issue here, where: + # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. + # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. + # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, + # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). + # This makes it pretty difficult for this logic to operate on synthetic bases. + # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual + # (unpacked) input aliases, instead of the synthetic base. + # Example case where (3) could be important: + # + # def f(x, y): + # x.mul_(2) + # y.mul_(3) + # return x, y + # a = torch.ones(1'000'000) + # x, y = out(a[0:9], a[1:10]) + # + # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing + # a giant "updated synthetic base" and copying into a's entire storage. + # + # For now, we are pessimistically not performing the optimization from (3); + # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. + # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry + # about synthetic bases. + + # Apply in graph forward mutations only in joint case. + # Note: Mutations of primals in forward AND backward. 
+ # If we have mutations of the same input in forward and in backward, + # we can not fuse them into one copy_ node. As in this case partitioner will put it + # either in forward or in backward. This will lead to incorrect state + # after forward and before backward. + # We have to emit two copy_ nodes, marking with additional meta each node, + # if it must be in forward or backward. + # We memorize mutation counter of the inputs after forward. + # Based on this after joint graph we check if backward also mutated input or not. + # We emit copy_ only in the end of joint tracing, to provide invariant for joint + # graph passes, that our graph is functional, except only some number of copy_ nodes + # in the end. + mcs_applied: list[MutationCounters] = [MutationCounters(0, 0, 0)] * len( + meta.input_info + ) + if f_args_mutation_counters_after_forward is not None: + primals_before = args[0] + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip( + f_args_after_forward, # type: ignore[arg-type] + primals_before, # type: ignore[arg-type] + primals_after_forward, # type: ignore[arg-type] + meta.input_info, + ) + ): + if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH: + continue + + mcs_after_forward = f_args_mutation_counters_after_forward[idx] + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_forward(), + _proxy_tensor_disable_update_tensor_tracker(), + ): + apply_in_graph_mutations( + inpt_info, + before, + after, + f_inpt, + idx, + mcs_after_forward, + mcs_applied[idx], + ) + mcs_applied[idx] = mcs_after_forward + + for idx, (inpt_old, f_inpt) in enumerate( + zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) # type: ignore[arg-type] + ): + if not isinstance(f_inpt, torch.Tensor): + continue + assert is_fun(f_inpt) + inpt_new = from_fun(f_inpt) + if ( + meta.input_info[idx].mutation_type + != MutationType.MUTATED_IN_GRAPH + ): + continue + mcs: Optional[MutationCounters] = None + if 
f_args_mutation_counters_after_forward is not None: + # This could happen for subclasses tracing + # Subclasses support for mutations in fw and bw is TBD. + mcs = _get_mutation_counters(f_inpt) + if mcs == mcs_applied[idx]: + # No mutation in backward; mutation was already applied. + continue + + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_backward(), + ): + apply_in_graph_mutations( + meta.input_info[idx], + inpt_old, + inpt_new, + f_inpt, + idx, + mcs, + mcs_applied[idx], + ) + + # When an output tensor is a functionalized mutated input, and we + # were able to move the mutation in to the graph then we can return + # the mutated input directly. This prevents duplicating the + # tensors contents. + flat_outs, outs_spec = pytree.tree_flatten(f_outs) + flat_outs = [from_fun(o) for o in flat_outs] + num_outs = len(meta.output_info) + + for i in range(num_outs): + info = meta.output_info[i] + if info.output_type != OutputType.is_input: + continue + + assert info.base_idx is not None + if ( + meta.input_info[info.base_idx].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + fw_args = args[0] if trace_joint else args + flat_outs[i] = fw_args[info.base_idx] + return pytree.tree_unflatten(flat_outs, outs_spec), f_outs_descs + + return pytree.tree_map(from_fun, f_outs), f_outs_descs + + # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" + # and "tangents" as its input names (which are special-cased by the partitioner) + # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export + def joint_helper(primals, tangents): + return _functionalized_f_helper(primals, tangents) + + helper = joint_helper if trace_joint else _functionalized_f_helper + if config.functionalize_rng_ops: + # Setup the wrapper for functionalization of rng ops + helper, args, args_descs = create_functionalized_rng_ops_wrapper( + helper, args, args_descs, trace_joint + ) + + return helper, args, 
def handle_effect_tokens_fn(
    fn,
    args,
    args_descs: list[AOTInput],
    *,
    meta: ViewAndMutationMeta,
    trace_joint: bool,
) -> Any:
    """Wrap ``fn`` so that effect tokens are threaded through as inputs/outputs.

    One empty tensor per entry of ``meta.tokens`` is prepended to the forward
    inputs (and a matching ``ForwardTokenAOTInput`` to the descriptors).  The
    returned wrapper strips those tokens off again, installs them on the
    active functional tensor mode, calls ``fn``, and returns the resulting
    tokens together with ``fn``'s outputs.

    Returns:
        ``(inner_fn, args, args_descs)`` with the token inputs prepended.

    Side effects:
        When ``trace_joint`` is true, calling the wrapper sets
        ``meta.num_backward_tokens``.
    """
    num_tokens = len(meta.tokens)

    @simple_wraps(fn)
    def inner_fn(*args):
        # See Note [Disabling Functionalize TLS Above Python Functionalization]
        functionalize_keys = torch._C.DispatchKeySet(
            torch._C.DispatchKey.Functionalize
        )

        with torch._C._ExcludeDispatchKeyGuard(functionalize_keys):
            # See Note [Side-Effectful Tokens in AOTAutograd]
            if trace_joint:
                assert isinstance(args, tuple) and isinstance(args[0], (list, tuple))
                primals = args[0]
                tokens = primals[:num_tokens]
                assert all(token.numel() == 0 for token in tokens)
                args = (primals[num_tokens:], *args[1:])
            else:
                tokens = args[:num_tokens]
                assert all(token.numel() == 0 for token in tokens)
                args = args[num_tokens:]

            # Populate the current FunctionalTensorMode with the tokens per
            # operator. See Note [FunctionalTensorMode is Stateful]
            mode = torch.utils._python_dispatch._detect_infra_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )
            assert mode is not None
            wrapped_tokens = pytree.tree_map(to_fun, tokens)
            for position, effect_key in enumerate(meta.tokens.keys()):
                mode._tokens[effect_key] = wrapped_tokens[position]

            # Run the joint
            outs, outs_descs = call_and_expect_output_descs(fn, args)

        # Return both the tokens and the outputs
        # See Note [Side-Effectful Tokens in AOTAutograd]
        if not trace_joint:
            out_tokens = [from_fun(t) for t in mode._tokens.values()]
            # TODO: can probably do a little more resolution here
            out_tokens_descs = [
                ForwardTokenAOTOutput(i) for i in range(len(mode._tokens.values()))
            ]
            return ((*out_tokens, *outs), (*out_tokens_descs, *outs_descs))

        assert len(outs) == 2
        assert len(mode._tokens_forward_output) == num_tokens
        fwd_out_tokens = mode._tokens_forward_output.values()
        bwd_out_tokens = mode._tokens.values()

        plain_fwd_tokens = [from_fun(t) for t in fwd_out_tokens]
        plain_bwd_tokens = [from_fun(t) for t in bwd_out_tokens]
        fwd_tokens_descs = [
            ForwardTokenAOTOutput(i) for i in range(len(fwd_out_tokens))
        ]
        bwd_tokens_descs = [
            BackwardTokenAOTOutput(i) for i in range(len(bwd_out_tokens))
        ]

        meta.num_backward_tokens = len(bwd_out_tokens)
        joint_outs = (
            (*plain_fwd_tokens, *outs[0]),
            (*outs[1], *plain_bwd_tokens),
        )
        joint_descs = (
            (*fwd_tokens_descs, *outs_descs[0]),
            (*outs_descs[1], *bwd_tokens_descs),
        )
        return joint_outs, joint_descs

    # Additionally pass in tokens as inputs
    # See Note [Side-Effectful Tokens in AOTAutograd]
    additional_fwd_token_inputs = [torch.tensor([])] * num_tokens
    additional_fwd_token_inputs_descs = [
        ForwardTokenAOTInput(i) for i in range(num_tokens)
    ]

    if trace_joint:
        args = ([*additional_fwd_token_inputs, *args[0]], *args[1:])
        args_descs = (  # type: ignore[assignment]
            [*additional_fwd_token_inputs_descs, *args_descs[0]],  # type: ignore[misc]
            *args_descs[1:],
        )
    else:
        args = [*additional_fwd_token_inputs, *args]
        args_descs = [*additional_fwd_token_inputs_descs, *args_descs]
    return inner_fn, args, args_descs
def aot_dispatch_subclass(
    flat_fn_maybe_joint: Union[JointTraceFn, TraceFn],
    args: Union[list[FxValue], tuple[list[FxValue], list[FxValue]]],
    args_descs: Union[list[AOTInput], tuple[list[AOTInput], list[AOTInput]]],
    *,
    is_joint_structure: bool,
    meta: ViewAndMutationMeta,
    fw_only: Callable,
) -> SubclassTracingInfo:
    """Turn a Subclass -> Subclass function into a Tensor -> Tensor function.

    Args:
        flat_fn_maybe_joint: the joint fw-bw function when
            ``is_joint_structure`` is true, otherwise the forward function.
        args: inputs for ``flat_fn_maybe_joint``; ``(primals, tangents)`` in
            the joint case.
        args_descs: descriptors matching the structure of ``args``.
        is_joint_structure: whether ``args`` is a ``(primals, tangents)`` pair.
        meta: metadata collected on the subclass -> subclass function.
        fw_only: the forward-only function; used to recompute
            ``ViewAndMutationMeta`` on the dense -> dense function.

    Returns:
        A ``SubclassTracingInfo`` holding the function to trace, the unwrapped
        arguments and descriptors, and the subclass metadata.  When no
        subclass dispatch is required the inputs are passed through and
        ``maybe_subclass_meta`` is ``None``.
    """
    # Skip logic if we don't need to trace through any subclasses
    req_subclass_dispatch = requires_subclass_dispatch(args, meta)
    if not req_subclass_dispatch:
        return SubclassTracingInfo(
            plain_tensor_trace_fn=flat_fn_maybe_joint,
            plain_tensor_args=args,
            plain_tensor_args_descs=args_descs,
            maybe_subclass_meta=None,
        )

    # TODO: add subclass guards (later PR).

    # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs).
    # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint,
    # so we set it later, while we're tracing the joint (see inner_fn() below).
    # Another option would be to run our run_functionalized_fw_and_collect_metadata() function
    # directly on the joint, but this would hurt compile time (adding yet another pass through the joint).
    subclass_meta = SubclassMeta()

    # NB: doesn't take descs, this is going from the NEW flat_args to the
    # subclasses, we don't need to do bookkeeping here
    def inner_fn(fn, args, *, use_trace_joint: bool):
        # Step 1: wrap tensor inputs into subclasses if necessary
        all_args = wrap_tensor_subclasses_maybe_joint(
            args, is_joint_structure=use_trace_joint, meta=meta
        )

        # Step 2: call the inner function, with our (maybe subclass) inputs
        wrapped_outs, wrapped_outs_descs = call_and_expect_output_descs(fn, all_args)

        if use_trace_joint:
            # See Note: [Computing Subclass Metadata about grad_inputs]
            # We also stash subclass info on our grad_inputs, if we're tracing the joint.
            nonlocal subclass_meta
            assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2, (
                wrapped_outs,
                wrapped_outs_descs,
            )
            # Don't need fw outs since we already have subclass metadata on them
            grad_inputs = wrapped_outs[1]
            subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs)

            # Add extra symints as outputs to the forward/backward graphs
            # ignore nested ints here
            forward_outs, forward_outs_descs = unwrap_tensor_subclasses(
                wrapped_outs[0], wrapped_outs_descs[0], append_symints=True
            )
            # ignore nested ints here
            backward_outs, backward_outs_descs = unwrap_tensor_subclasses(
                wrapped_outs[1], wrapped_outs_descs[1], append_symints=True
            )
            return (
                (forward_outs, backward_outs),
                (forward_outs_descs, backward_outs_descs),
            )

        # Step 3: Unwrap any subclass outputs back into dense tensors
        return unwrap_tensor_subclasses(
            wrapped_outs, wrapped_outs_descs, append_symints=True
        )

    def joint_fn(
        primals: list[FxValue], tangents: list[FxValue]
    ) -> tuple[
        tuple[list[FxValue], list[FxValue]], tuple[list[AOTOutput], list[AOTOutput]]
    ]:
        with maybe_enable_thunkify():
            return inner_fn(
                flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
            )

    def fw_fn(*primals: FxValue) -> tuple[list[FxValue], list[AOTOutput]]:
        with maybe_enable_thunkify():
            return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)

    def metadata_fn(*primals: FxValue) -> tuple[list[FxValue], list[AOTOutput]]:
        @simple_wraps(fw_only)
        def inner_fw_only(*args):
            return call_and_expect_output_descs(fw_only, args)

        return inner_fn(inner_fw_only, primals, use_trace_joint=False)

    if is_joint_structure:
        # Add extra symints (size/strides) as input to the forward graph
        primals_unwrapped_pair = unwrap_tensor_subclasses(
            args[0],  # type: ignore[arg-type]
            args_descs[0],  # type: ignore[arg-type]
            append_symints=True,
        )
        # We pass append_symints=False here because the partitioner will
        # capture and add any extra argument
        tangents_unwrapped_pair = unwrap_tensor_subclasses(
            args[1],  # type: ignore[arg-type]
            args_descs[1],  # type: ignore[arg-type]
            append_symints=False,
        )

        args_unwrapped = (primals_unwrapped_pair[0], tangents_unwrapped_pair[0])
        args_descs_unwrapped = (primals_unwrapped_pair[1], tangents_unwrapped_pair[1])
        remapped_static_indices = remap_unwrapped_subclass_arg_indices(
            args[0], meta.static_input_indices
        )
    else:
        args_unwrapped, args_descs_unwrapped = unwrap_tensor_subclasses(  # type: ignore[assignment]
            args,  # type: ignore[arg-type]
            args_descs,  # type: ignore[arg-type]
            append_symints=True,
        )
        remapped_static_indices = remap_unwrapped_subclass_arg_indices(
            args, meta.static_input_indices
        )

    if is_joint_structure:
        primals_unwrapped = args_unwrapped[0]  # type: ignore[assignment]
        primals_unwrapped_descs = args_descs_unwrapped[0]  # type: ignore[assignment]
        fn_to_trace = joint_fn  # type: ignore[assignment]
    else:
        primals_unwrapped = args_unwrapped  # type: ignore[assignment]
        primals_unwrapped_descs = args_descs_unwrapped  # type: ignore[assignment]
        fn_to_trace = fw_fn  # type: ignore[assignment]

    # Note: [Partitioner handling for Subclasses, Part 1]
    # The way the partitioner works is that:
    # (1) we pass in a single graph containing the joint fw/bw,
    #     where the # of graph outputs corresponds to # fw_outputs + # grad_inputs
    # (2) The partitioner accepts an argument, num_fwd_outputs,
    #     and assumes that the first "num_fwd_outputs" graph outputs correspond
    #     to outputs of the forward graph.
    # How do tensor subclasses enter the picture?
    # the num_fwd_outputs in the final graph is actually non-trivial to compute,
    # because it can be influenced by input mutations and intermediate bases.
    # So we compute it by inspecting the current ViewAndMutationMeta object.
    # However, the original ViewAndMutationMeta that we computed was created
    # on the subclass -> subclass graph,
    # which can have a different number of outputs than the dense -> dense graph.
    # That's why we created a fresh metadata object on the dense -> dense function here,
    # and plumb it back up to the partitioner.
    # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
    meta_updated = run_functionalized_fw_and_collect_metadata(
        without_output_descs(metadata_fn),
        # pyrefly: ignore [bad-argument-type]
        flat_args_descs=primals_unwrapped_descs,
        static_input_indices=remapped_static_indices,
        keep_input_mutations=meta.keep_input_mutations,
        is_train=meta.is_train,
        # pyrefly: ignore [not-iterable]
    )(*primals_unwrapped)

    subclass_meta.fw_metadata = meta_updated

    return SubclassTracingInfo(
        plain_tensor_trace_fn=fn_to_trace,
        plain_tensor_args=args_unwrapped,
        plain_tensor_args_descs=args_descs_unwrapped,
        maybe_subclass_meta=subclass_meta,
    )


def create_functional_call(
    mod, params_spec, params_len, store_orig_mod=False, strict_out_tuple=True
):
    """Build a function that calls ``mod`` with its parameters passed explicitly.

    The returned ``functional_call(*args, **kwargs)`` treats the first
    ``params_len`` positional arguments as flat parameters, rebuilds them via
    ``params_spec`` (a ``TreeSpec`` or a list of names), temporarily
    reparametrizes ``mod`` with them, and runs the module on the remaining
    arguments.

    Args:
        mod: the module to call.
        params_spec: a ``TreeSpec`` for unflattening the parameters, or a list
            of parameter names zipped with the flat parameters.
        params_len: number of leading positional arguments that are parameters.
        store_orig_mod: when true, stash ``mod`` on the returned function as
            ``_orig_mod`` (unless that attribute already exists).
        strict_out_tuple: when true, require the module output to be a tuple
            or list.

    Raises:
        RuntimeError: from the returned function, if ``strict_out_tuple`` is
            set and the output is neither a tuple nor a list.
    """
    # Redundant with dynamo, but worth having in case this gets invoked elsewhere.
    # https://github.com/pytorch/pytorch/issues/103569

    @simple_wraps(mod)
    def functional_call(*args, **kwargs):
        flat_params = args[:params_len]
        if isinstance(params_spec, TreeSpec):
            params = pytree.tree_unflatten(flat_params, params_spec)
        else:
            assert isinstance(params_spec, list)
            params = dict(zip(params_spec, flat_params))
        with (
            stateless._reparametrize_module(mod, params),
            maybe_disable_thunkify(),
        ):
            if isinstance(mod, torch.fx.GraphModule):
                # Handle **kwargs. FX only natively supports positional
                # arguments (through placeholders), so keyword values are
                # appended positionally in the order they were passed.
                args = (*args[params_len:], *kwargs.values())

                with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore", "Anomaly Detection has been enabled."
                    )
                    with torch.autograd.detect_anomaly(check_nan=False):
                        fake_mode = detect_fake_mode()
                        assert fake_mode is not None
                        fake_mode.epoch += 1
                        out = PropagateUnbackedSymInts(mod).run(*args)
            else:
                out = mod(*args[params_len:], **kwargs)

        if strict_out_tuple and not isinstance(out, (tuple, list)):
            # The check above accepts tuples and lists; say so explicitly.
            raise RuntimeError(
                "Graph output must be a tuple or list. This is so that we can avoid "
                "pytree processing of the outputs. Please change the module to "
                "have tuple outputs or use aot_module instead."
            )
        return out

    # Note [Preserving the nn module stack metadata during export non-strict mode]
    # This path is currently only used by the non-strict export flow,
    # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph.
    # Instead, we stash the original user nn module here, and rely on `make_fx` to grab
    # this stashed module and use it to track nn module stack metadata
    if store_orig_mod and not hasattr(functional_call, "_orig_mod"):
        functional_call._orig_mod = mod  # type: ignore[attr-defined]

    return functional_call
class IndexedDict(MutableMapping[K, V], Generic[K, V]):
    """Insertion-ordered mapping with O(1) access to a key's neighbours.

    Three structures are kept in sync: the value store, the keys in insertion
    order, and each key's position in that order.  Re-assigning an existing
    key updates its value but keeps its position.  Deletion is not supported.
    """

    __slots__ = ("_dict", "_keys", "_key_to_index")

    def __init__(self) -> None:
        self._dict: dict[K, V] = {}
        self._keys: list[K] = []  # typing: ignore[bad-override]
        self._key_to_index: dict[K, int] = {}

    def __setitem__(self, key: K, value: V) -> None:
        # Only a first-time insertion claims a slot in the ordering.
        is_new_key = key not in self._dict
        if is_new_key:
            self._key_to_index[key] = len(self._keys)
            self._keys.append(key)
        self._dict[key] = value

    def __getitem__(self, key: K) -> V:
        return self._dict[key]

    def __delitem__(self, key: K) -> None:
        raise NotImplementedError("Deletion not supported for IndexedDict")

    def __len__(self) -> int:
        return len(self._dict)

    def __iter__(self) -> Iterator[K]:
        return iter(self._keys)

    def __contains__(self, key: object) -> bool:
        return key in self._dict

    def next_key(self, key: K) -> Optional[K]:
        """Return the key inserted right after ``key``, or ``None``.

        ``None`` is returned both for an unknown key and for the last key.
        """
        position = self._key_to_index.get(key)
        if position is None:
            return None
        successor = position + 1
        if successor >= len(self._keys):
            return None
        return self._keys[successor]

    def prev_key(self, key: K) -> Optional[K]:
        """Return the key inserted right before ``key``, or ``None``.

        ``None`` is returned both for an unknown key and for the first key.
        """
        position = self._key_to_index.get(key)
        if position is None or position <= 0:
            return None
        return self._keys[position - 1]
def remove_dupe_metadata(
    m: ViewAndMutationMeta,
    keep_arg_mask: list[bool],
    add_dupe_map: list[int],
) -> ViewAndMutationMeta:
    """Rebuild ``m`` as it would look after duplicate inputs were removed.

    Args:
        m: metadata computed with the duplicated inputs still present.
        keep_arg_mask: one flag per entry of ``m.input_info``; ``True`` keeps
            that input.
        add_dupe_map: maps each original input index to its index after
            de-duplication; used to rewrite ``base_idx`` on the outputs.

    Returns:
        A new ``ViewAndMutationMeta`` with filtered inputs, remapped output
        base indices, and tangents of dropped mutated inputs removed.
    """
    assert len(m.input_info) == len(keep_arg_mask)
    # Easy invariant: the first argument should never be a dupe (it will be kept)
    assert len(keep_arg_mask) > 0 and keep_arg_mask[0]

    # Filter dupe'd mutated inputs out of traced_tangents.
    # The leading `split` tangents belong to mutated inputs, the rest do not.
    split = sum(1 for info in m.input_info if info.mutates_data)
    runtime_indices = m.mutated_inp_runtime_indices

    def is_kept(position: int) -> bool:
        return keep_arg_mask[runtime_indices[position]]

    # See Note [Tangents memory format]
    kept_inp_tangents = [
        tangent
        for pos, tangent in enumerate(m.traced_tangents[:split])
        if is_kept(pos)
    ]
    kept_inp_tangents_descs = [
        desc
        for pos, desc in enumerate(m.traced_tangents_descs[:split])
        if is_kept(pos)
    ]
    traced_tangents = kept_inp_tangents + m.traced_tangents[split:]
    traced_tangents_descs = kept_inp_tangents_descs + m.traced_tangents_descs[split:]

    assert m.subclass_tangent_meta is not None
    default_tangent_meta = PlainTensorMeta(
        0, memory_format=MemoryFormatMeta(memory_format=torch.contiguous_format)
    )
    # List multiplication on purpose: every slot refers to the same meta object.
    subclass_tangent_meta = (
        [default_tangent_meta] * len(kept_inp_tangents)
        + m.subclass_tangent_meta[split:]
    )

    kept_input_info = [
        info for idx, info in enumerate(m.input_info) if keep_arg_mask[idx]
    ]
    # For outputs that are views of inputs, we store the index of the input that the output
    # was generated from. Need to update that index to account for removed dupes.
    remapped_output_info = [
        OutputAliasInfo(
            output_type=out.output_type,
            raw_type=out.raw_type,
            dynamic_dims=out.dynamic_dims,
            base_idx=None if out.base_idx is None else add_dupe_map[out.base_idx],
            requires_grad=out.requires_grad,
            view_meta_sequence=out.view_meta_sequence,
        )
        for out in m.output_info
    ]

    return ViewAndMutationMeta(
        input_info=kept_input_info,
        output_info=remapped_output_info,
        num_intermediate_bases=m.num_intermediate_bases,
        keep_input_mutations=m.keep_input_mutations,
        traced_tangents=traced_tangents,
        traced_tangents_descs=traced_tangents_descs,
        # We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
        subclass_inp_meta=[],
        subclass_fw_graph_out_meta=[],
        subclass_tangent_meta=subclass_tangent_meta,
        is_train=m.is_train,
    )
#
# In addition to the updated metadata, also return the list of input indices
# that will need to be updated in the synthetic base epilogue
def create_synthetic_base_metadata(
    m: ViewAndMutationMeta,
    # Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a
    # synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)
    synthetic_base_info: list[Union[int, tuple[int, torch.Tensor]]],
    outer_args: list[Any],
    inner_args: list[Any],
    inner_args_desc: list[AOTInput],
) -> tuple[ViewAndMutationMeta, list[int]]:
    """Rewrite ``m`` for the calling convention that uses synthetic bases.

    Args:
        m: metadata for the original ("outer") arguments.
        synthetic_base_info: per outer argument, either its inner index or a
            ``(inner_base_idx, view_meta)`` tuple.
        outer_args: the original arguments.
        inner_args: the arguments after synthetic bases were introduced.
        inner_args_desc: descriptors for ``inner_args``.

    Returns:
        ``(new_meta, outer_indices)`` where ``outer_indices`` lists the outer
        arguments that come from a synthetic base and have a metadata mutation.
    """
    # maps inner arg indices to outer arg indices.
    # Built in one pass over the outer args; every inner index gets an entry
    # (possibly empty) and outer indices stay in ascending order.
    synthetic_base_to_indices: dict[int, list[int]] = {
        inner_idx: [] for inner_idx in range(len(inner_args))
    }
    for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info):
        if isinstance(inner_idx_or_tuple, int):
            base_idx = inner_idx_or_tuple
        elif isinstance(inner_idx_or_tuple, tuple):
            base_idx = inner_idx_or_tuple[0]
        else:
            continue
        aliases = synthetic_base_to_indices.get(base_idx)
        if aliases is not None:
            aliases.append(outer_idx)

    # given the requires_grad info on mutated inputs,
    # generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases.
    input_infos = []
    for outer_indices in synthetic_base_to_indices.values():
        # leaf-ness should be all-or-nothing for aliased tensor.
        # (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf)
        any_leaf = any(m.input_info[x].is_leaf for x in outer_indices)
        all_leaf = all(m.input_info[x].is_leaf for x in outer_indices)
        assert any_leaf == all_leaf

        mutates_data = (
            True
            if len(outer_indices) > 1
            else m.input_info[outer_indices[0]].mutates_data
        )
        mutates_metadata = (
            False
            if len(outer_indices) > 1
            else m.input_info[outer_indices[0]].mutates_metadata
        )
        requires_grad = any(m.input_info[x].requires_grad for x in outer_indices)
        mutations_under_no_grad_or_inference_mode = all(
            m.input_info[x].mutations_under_no_grad_or_inference_mode
            for x in outer_indices
        )

        mutation_inductor_storage_resize = all(
            m.input_info[x].mutation_inductor_storage_resize for x in outer_indices
        )

        inpt_info = InputAliasInfo(
            # If len(outer_indices) > 1, then this input is a synthetic base.
            # The invariant is that to the rest of aot autograd, synthetic bases only show up if
            # one of their aliases gets a data mutation. And if any of their aliases get metadata
            # mutations, they will be hidden from the rest of aot autograd.
            mutates_data=mutates_data,
            mutates_metadata=mutates_metadata,
            mutations_hidden_from_autograd=all(
                m.input_info[x].mutations_hidden_from_autograd for x in outer_indices
            ),
            mutates_storage_metadata=(
                False
                if len(outer_indices) > 1
                else m.input_info[outer_indices[0]].mutates_storage_metadata
            ),
            mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode,
            mutation_inductor_storage_resize=mutation_inductor_storage_resize,
            is_leaf=any_leaf,
            requires_grad=requires_grad,
            keep_input_mutations=m.keep_input_mutations,
        )
        input_infos.append(inpt_info)

    # Find any inputs that fulfill the following criteria:
    # (1) They are part of a synthetic base (because they alias another input,
    #      and at least one input experiences a data mutation)
    # (2) They experience a metadata mutation
    outer_aliased_arg_idx_with_metadata_mutations = [
        outer_idx
        for outer_idx, inpt_info in enumerate(m.input_info)
        if inpt_info.mutates_metadata
        and not isinstance(synthetic_base_info[outer_idx], int)
    ]

    # grab the original requires grad info on the outputs, except the ones from the mutated inputs
    input_metadata_output_info = [
        OutputAliasInfo(
            output_type=OutputType.alias_of_input,
            raw_type=FunctionalTensor,
            dynamic_dims={
                i
                for i, s in enumerate(outer_args[outer_idx].shape)
                if not is_concrete_int(s)
            },
            base_idx=synthetic_base_info[outer_idx][0],  # type: ignore[index]
            requires_grad=outer_args[outer_idx].requires_grad,
        )
        for outer_idx in outer_aliased_arg_idx_with_metadata_mutations
    ]
    existing_output_infos = []
    for o in m.output_info:
        new_base_idx = (
            None
            if o.base_idx is None
            else (
                synthetic_base_info[o.base_idx]
                if isinstance(synthetic_base_info[o.base_idx], int)
                else synthetic_base_info[o.base_idx][0]  # type: ignore[index]
            )
        )
        # If base_idx is changed for OutputType.is_input, we need to update the output type to reflect the change
        new_output_type = (
            OutputType.alias_of_input
            if o.output_type == OutputType.is_input and o.base_idx != new_base_idx
            else o.output_type
        )
        existing_output_infos.append(
            OutputAliasInfo(
                output_type=new_output_type,
                raw_type=o.raw_type,
                dynamic_dims=o.dynamic_dims,
                # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
                base_idx=new_base_idx,  # type: ignore[arg-type]
                requires_grad=o.requires_grad,
                view_meta_sequence=o.view_meta_sequence,
            )
        )

    inner_mutated_tangents_and_memory_formats = [
        # See Note [Tangents memory format]
        (
            coerce_tangent_and_suggest_memory_format(x),
            TangentAOTInput(InputMutationAOTOutput(x_desc)),
        )
        for inner_idx, (x, x_desc) in enumerate(zip(inner_args, inner_args_desc))
        if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
    ]
    inner_mutated_tangents = [
        x[0][0] for x in inner_mutated_tangents_and_memory_formats
    ]
    inner_mutated_tangents_descs = [
        x[1] for x in inner_mutated_tangents_and_memory_formats
    ]
    inner_mutated_tangents_memory_formats = [
        x[0][1] for x in inner_mutated_tangents_and_memory_formats
    ]

    output_info = existing_output_infos + input_metadata_output_info
    # Regenerate traced tangents to include mutated inputs including synthetic bases
    traced_tangents = (
        inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :]
    )
    traced_tangents_descs = (
        inner_mutated_tangents_descs
        + m.traced_tangents_descs[len(inner_mutated_tangents) :]
    )
    assert m.subclass_tangent_meta is not None
    subclass_tangent_meta = [
        PlainTensorMeta(0, memory_format=x)
        for x in inner_mutated_tangents_memory_formats
    ] + m.subclass_tangent_meta[len(inner_mutated_tangents) :]

    return (
        ViewAndMutationMeta(
            input_info=input_infos,
            output_info=output_info,
            num_intermediate_bases=m.num_intermediate_bases,
            keep_input_mutations=m.keep_input_mutations,
            traced_tangents=traced_tangents,
            traced_tangents_descs=traced_tangents_descs,
            # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs.
            subclass_inp_meta=[],
            subclass_fw_graph_out_meta=[],
            subclass_tangent_meta=subclass_tangent_meta,
            is_train=m.is_train,
        ),
        outer_aliased_arg_idx_with_metadata_mutations,
    )


def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices):
    """Return the subset of ``aliased_input_indices`` whose tensors overlap.

    Overlap is computed by ``compute_overlapping_tensors`` on the selected
    ``fwd_inputs``.  When a tracing context is active, symbolic sizes, strides
    or offsets are involved, more than one aliased input exists, and
    ``aot_config.aot_autograd_arg_pos_to_source`` is set, a ``StorageOverlap``
    guard is appended to the tracing context as a side effect.

    Returns:
        A set of indices into ``fwd_inputs``.
    """
    num_aliases = len(aliased_input_indices)

    shape_env = None
    maybe_suppress_guards = contextlib.nullcontext
    tracing_context = torch._guards.TracingContext.try_get()

    if tracing_context is not None:
        assert tracing_context.fake_mode is not None
        shape_env = tracing_context.fake_mode.shape_env

        # Check whether we can actually get the dynamo sources from within AOTAutograd.
        if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None:
            maybe_suppress_guards = shape_env.suppress_guards  # type: ignore[assignment]

    # Check whether there are any symbolic values being used.
    # We do this for 2 reasons:
    #   1. StorageOverlap guard is only issued whenever dynamic shapes is turned on
    #   2. Triggers the fast-path for computing storage overlapping
    symbolic = any(
        isinstance(x, torch.SymInt)
        for i in aliased_input_indices
        for x in [
            *fwd_inputs[i].shape,
            *fwd_inputs[i].stride(),
            fwd_inputs[i].storage_offset(),
        ]
    )

    if torch._inductor.config.is_fbcode():
        if symbolic and num_aliases > 400:
            from torch._subclasses.fake_tensor import (
                UnsupportedMutationAliasingException,
            )
            from torch._utils_internal import justknobs_check

            msg = f"Encountered {num_aliases} dynamic, aliased/mutated inputs, consider setting dynamic=False"

            if justknobs_check(
                "pytorch/compiler:aliased_inputs_with_mutation_and_dyn_shapes_killswitch",
                False,
            ):
                raise UnsupportedMutationAliasingException(msg)

    with maybe_suppress_guards():
        aliased_fwd_inputs = [fwd_inputs[i] for i in aliased_input_indices]
        actual_aliased_indices = {
            aliased_input_indices[i]
            for i in compute_overlapping_tensors(aliased_fwd_inputs, symbolic=symbolic)
        }

    # Add the StorageOverlap AOTAutograd guard only if we are actually keeping track of
    # dynamo sources inside AOTAutograd.
    if (
        tracing_context is not None
        # Make sure dynamic shapes is currently being used.
        and symbolic
        # We check that we have more than 1 aliased tensor, which should be true at
        # this point, anyway.
        and num_aliases > 1
        and aot_config.aot_autograd_arg_pos_to_source
    ):
        no_overlap_indices = list(set(aliased_input_indices) - actual_aliased_indices)

        overlapping_sources = [
            aot_config.aot_autograd_arg_pos_to_source[i] for i in actual_aliased_indices
        ]
        non_overlapping_sources = [
            aot_config.aot_autograd_arg_pos_to_source[i] for i in no_overlap_indices
        ]

        tracing_context.guards_context.aotautograd_guards.append(
            StorageOverlap(overlapping_sources, non_overlapping_sources)
        )

    return actual_aliased_indices


def _graph_input_names(gm):
    """Return the names of all placeholder nodes of ``gm.graph``."""
    return [node.name for node in gm.graph.find_nodes(op="placeholder")]


def _graph_output_names(gm):
    """Return the names of the values returned by ``gm``'s output node.

    The last node of the graph must be the single-argument ``output`` node.
    Returned values without a ``name`` attribute are reported as ``None``.
    """
    output_node = next(iter(reversed(gm.graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    return_args = output_node.args[0]
    return [getattr(return_arg, "name", None) for return_arg in return_args]


def create_graph_signature(
    fx_g: torch.fx.GraphModule,
    fw_metadata: ViewAndMutationMeta,
    in_spec: pytree.TreeSpec,
    out_spec: pytree.TreeSpec,
    *,
    user_args_flat: list[Tensor],
    params_and_buffers_flat: list[Tensor],
    param_names: list[str],
    buffer_names: list[str],
    trace_joint: bool,
    num_user_fw_outs: Optional[int],
    loss_index: Optional[int],
) -> GraphSignature:
    """Build the ``GraphSignature`` describing ``fx_g``'s inputs and outputs.

    For a joint graph (``trace_joint=True``) both ``num_user_fw_outs`` and
    ``loss_index`` must be provided; the backward outputs are then matched, in
    order, to the parameters and user inputs that require grad.  For a
    forward-only graph the number of user outputs is derived from the graph
    outputs, the mutated-input count and the token count.
    """
    # Retrieve graph input names
    graph_input_names = _graph_input_names(fx_g)
    # Retrieve graph output names
    graph_output_names = _graph_output_names(fx_g)

    num_params_buffers = len(param_names) + len(buffer_names)
    num_tokens = len(fw_metadata.tokens)
    # We have enough restrictions on the graph (no de-duping, synthetic bases, etc),
    # Such that # graph inps = # user inps + # params + # buffers
    num_user_args = len(graph_input_names) - num_params_buffers - num_tokens

    if trace_joint:
        assert num_user_fw_outs is not None
        # The loss output is looked up by index below, so it must be given too.
        assert loss_index is not None
        num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices
        backward_output_names = graph_output_names[num_fw_outs:]

        # Shared counter: parameters consume the first backward outputs,
        # user inputs the following ones.
        grad_index = itertools.count(0)
        gradients_to_parameters = {
            backward_output_names[next(grad_index)]: param_names[i]
            for i, param in enumerate(params_and_buffers_flat)
            if param.requires_grad
        }

        gradients_to_user_inputs = {
            backward_output_names[next(grad_index)]: graph_input_names[
                i + len(params_and_buffers_flat)
            ]
            for i, user_input in enumerate(user_args_flat)
            if user_input.requires_grad
        }

        # Check that we have fully accounted for all graph outputs
        assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len(
            backward_output_names
        )

        backward_signature = BackwardSignature(
            gradients_to_parameters,
            gradients_to_user_inputs,
            graph_output_names[loss_index],
        )
    else:
        backward_signature = None
        num_user_fw_outs = (
            len(graph_output_names)
            - fw_metadata.num_mutated_inp_runtime_indices
            - num_tokens
        )

    return GraphSignature.from_tracing_metadata(
        in_spec=in_spec,
        out_spec=out_spec,
        graph_input_names=graph_input_names,
        graph_output_names=graph_output_names,
        view_mutation_metadata=fw_metadata,
        named_parameters=param_names,
        named_buffers=buffer_names,
        num_user_inputs=num_user_args,
        num_user_outputs=num_user_fw_outs,
        trace_joint=trace_joint,
        loss_index=loss_index,
        backward_signature=backward_signature,
    )
handle functionalized randomness +4. deduplicate inputs and consolidate views into their bases (see input_output_analysis) +""" + +import builtins +import collections +import contextlib +import copy +import functools +import itertools +import pprint +from collections.abc import Callable +from contextlib import AbstractContextManager, nullcontext +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Optional, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from collections.abc import Sequence + +import torch +import torch.fx as fx +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo import config as dynamo_config +from torch._dynamo.callback import callback_handler, CallbackTrigger +from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context +from torch._guards import ( + compile_context, + CompileContext, + detect_fake_mode, + DuplicateInputs, + tracing, + TracingContext, +) +from torch._prims_common import CUDARngStateHelper +from torch._subclasses import FakeTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. 
import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .descriptors import ( + AOTInput, + AOTOutput, + DummyAOTInput, + MetadataMutationAOTOutput, + SyntheticBaseAOTInput, + ViewBaseAOTInput, +) +from .functional_utils import gen_alias_from_base +from .graph_capture_wrappers import aot_dispatch_subclass +from .input_output_analysis import ( + compute_overlapping_inputs, + create_synthetic_base_metadata, + remove_dupe_metadata, +) +from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling +from .schemas import ( + AOTConfig, + CompilerWrapper, + FxValue, + InductorWrapper, + InputAliasInfo, + MemoryFormatMeta, + MutationType, + OutputType, + PlainTensorMeta, + SubclassCreationMeta, + SubclassMeta, + TensorAlias, + TraceFn, + ViewAndMutationMeta, +) +from .subclass_utils import ( + requires_subclass_dispatch, + runtime_unwrap_tensor_subclasses, + wrap_tensor_subclasses, +) +from .utils import ( + call_and_expect_output_descs, + call_func_at_runtime_with_args, + make_boxed_func, + partial_flatten_asdict, + simple_wraps, + strict_zip, + without_output_descs, +) + + +zip = strict_zip + + +# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic +# that needs to run after the compiled function. +# +# This function accepts a trace_joint flag, indicating whether or not we're generating the runtime +# epilogue for a forward-only inference graph, or for an autograd.Function.apply function. +# This is because there are some minor differences in how we treat these cases at runtime: +# - resize_() is currently handled in the inference case, but not fully handled in the autograd case. 
# - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs
@dataclass
class RuntimeWrapper(CompilerWrapper):
    """CompilerWrapper whose ``post_compile`` installs the runtime epilogue.

    The actual wrapper is built by ``_create_runtime_wrapper``; this class only
    carries the configuration that is forwarded to it.
    """

    # Positions of inputs that are detached before calling the compiled
    # function (only consulted when trace_joint is set).
    indices_of_inps_to_detach: list[int]
    # Selects the autograd (joint) path vs. the inference path at runtime.
    trace_joint: bool
    # Forwarded unchanged to call_func_at_runtime_with_args.
    disable_amp: bool

    def post_compile(
        self,
        compiled_fn,
        aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Return ``compiled_fn`` wrapped by ``_create_runtime_wrapper``.

        ``keep_input_mutations`` is taken from
        ``aot_config.keep_inference_input_mutations``; every other option comes
        from this instance's fields.
        """
        return _create_runtime_wrapper(
            compiled_fn,
            runtime_metadata=runtime_metadata,
            indices_of_inps_to_detach=self.indices_of_inps_to_detach,
            trace_joint=self.trace_joint,
            keep_input_mutations=aot_config.keep_inference_input_mutations,
            disable_amp=self.disable_amp,
        )


class NoopAliasHandler:
    """Output handler that returns the compiled output unchanged."""

    def __init__(self, info, runtime_metadata, trace_joint):
        # All arguments are accepted only so every handler shares one
        # constructor signature (see make_output_handler); none are stored.
        pass

    def __call__(self, orig_inputs, fw_outs, out):
        return out


def _unwrap_tensoralias(x):
    """Return the ``alias`` attribute of ``x``, asserting that ``x`` is a TensorAlias."""
    assert isinstance(x, TensorAlias)
    return x.alias


def _identity(x):
    """Return ``x`` unchanged (used where no TensorAlias unwrapping is needed)."""
    return x


class AliasOfInputHandler:
    """Output handler that regenerates an output from the graph input it aliases.

    At call time the original input at ``info.base_idx`` is passed, together
    with the (possibly unwrapped) compiled output, to ``gen_alias_from_base``.
    """

    def __init__(self, info, runtime_metadata, trace_joint):
        self.base_idx = info.base_idx
        # Compiled outputs are unwrapped from TensorAlias only when trace_joint is set.
        self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
        self.requires_grad = info.requires_grad
        self.view_meta_sequence = info.view_meta_sequence
        # Config value is read once, at handler construction time.
        self.replay_views = config.view_replay_for_aliased_outputs

    def __call__(self, orig_inputs, fw_outs, out):
        aliased_base_tensor = orig_inputs[self.base_idx]
        return gen_alias_from_base(
            aliased_base_tensor,
            self.unwrap_out(out),
            self.requires_grad,
            self.view_meta_sequence,
            replay_views=self.replay_views,
        )


class IsInputHandler:
    """Output handler that returns the original input at ``info.base_idx``.

    ``fw_outs`` and ``out`` are ignored by ``__call__``; ``unwrap_out`` is
    stored by ``__init__`` but not read by ``__call__``.
    """

    def __init__(self, info, runtime_metadata, trace_joint):
        self.base_idx = info.base_idx
        self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity

    def __call__(self, orig_inputs, fw_outs, out):
        aliased_base_tensor = orig_inputs[self.base_idx]
        return aliased_base_tensor


class AliasOfIntermediateHandler:
    """Output handler that regenerates an output from another forward output.

    The base tensor is looked up in ``fw_outs`` (not in the original inputs)
    and passed, with the compiled output, to ``gen_alias_from_base``.
    """

    def __init__(self, info, runtime_metadata, trace_joint):
        self._unwrap_aliased_base_tensor = _identity
        if info.output_type in (
            OutputType.alias_of_intermediate,
            OutputType.alias_of_intermediate_save_as_output,
        ):
            # For these output types the base index is offset by the number of
            # entries in runtime_metadata.output_info.
            num_user_outputs = len(runtime_metadata.output_info)
            self.base_idx = info.base_idx + num_user_outputs
        else:
            self.base_idx = info.base_idx
            # If the base is itself listed in aliased_out_indices, it is
            # unwrapped from its TensorAlias before use.
            if self.base_idx in runtime_metadata.aliased_out_indices:
                self._unwrap_aliased_base_tensor = _unwrap_tensoralias

        self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
        self.requires_grad = info.requires_grad
        self.view_meta_sequence = info.view_meta_sequence
        self.replay_views = config.view_replay_for_aliased_outputs

    def __call__(self, orig_inputs, fw_outs, out):
        aliased_base_tensor = fw_outs[self.base_idx]
        return gen_alias_from_base(
            self._unwrap_aliased_base_tensor(aliased_base_tensor),
            self.unwrap_out(out),
            self.requires_grad,
            self.view_meta_sequence,
            replay_views=self.replay_views,
        )


# Maps each OutputType to the handler class that produces the corresponding
# user-visible output at runtime.
_HANDLER_MAP = {
    OutputType.non_alias: NoopAliasHandler,
    OutputType.unsafe_view_alias: NoopAliasHandler,
    OutputType.custom_function_view: NoopAliasHandler,
    OutputType.alias_of_input: AliasOfInputHandler,
    OutputType.is_input: IsInputHandler,
    OutputType.alias_of_intermediate: AliasOfIntermediateHandler,
    OutputType.alias_of_intermediate_save_as_output: AliasOfIntermediateHandler,
    OutputType.alias_of_intermediate_base_is_user_output: AliasOfIntermediateHandler,
}


def make_output_handler(info, runtime_metadata, trace_joint):
    """Instantiate the handler registered in ``_HANDLER_MAP`` for ``info.output_type``.

    Raises ``KeyError`` if the output type has no entry in ``_HANDLER_MAP``.
    """
    handler_type = _HANDLER_MAP[info.output_type]
    return handler_type(info, runtime_metadata, trace_joint)


# not sure why AOTDispatcher needs to manually set this
def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
    """Merge ``dims`` into ``t._dynamo_weak_dynamic_indices``.

    If the attribute already exists it is updated in place with ``|=``;
    otherwise it is created as a copy of ``dims`` so the caller's set is not
    shared with the tensor.
    """
    if hasattr(t, "_dynamo_weak_dynamic_indices"):
        # pyrefly: ignore [missing-attribute]
        t._dynamo_weak_dynamic_indices |= dims
    else:
        t._dynamo_weak_dynamic_indices = dims.copy()  # type: ignore[attr-defined]


def _should_disable_saved_tensors_hooks():
    """Return True when the current top saved-tensors hooks are inlineable.

    Always returns False inside a compiled autograd region.
    """
    # Compiled autograd is not supported yet, to be added in future.
    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        return False

    get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
    are_inline_hooks = (
        torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
    )

    hooks = get_hooks()
    if are_inline_hooks(hooks):
        return True

    return False


def _create_runtime_wrapper(
    compiled_fn,
    *,
    runtime_metadata: ViewAndMutationMeta,
    indices_of_inps_to_detach: list[int],
    trace_joint: bool,
    keep_input_mutations: bool,
    disable_amp: bool,
):
    """Wrap ``compiled_fn`` with the runtime prologue/epilogue for mutations and aliasing.

    Args:
        compiled_fn: The compiled callable. It is boxed with ``make_boxed_func``
            unless it already carries a truthy ``_boxed_call`` attribute.
        runtime_metadata: Describes mutated inputs, output aliasing, dynamic
            output dims and grad-mode mutation; read both here and on every call.
        indices_of_inps_to_detach: Input positions that are detached before the
            call when ``trace_joint`` is set (tensors only).
        trace_joint: If true, run under ``enable_grad`` and expect some outputs
            to be ``TensorAlias`` objects; otherwise run with grad disabled.
        keep_input_mutations: If true, bump the version counter of the inputs
            listed in ``mutated_graph_handled_indices_seen_by_autograd`` before
            running the compiled function.
        disable_amp: Forwarded to ``call_func_at_runtime_with_args``.

    Returns:
        A boxed-style callable taking a list of args. When ``trace_joint`` is
        set and ``_should_disable_saved_tensors_hooks()`` is true, the callable
        additionally runs under ``_disable_saved_tensors_hooks()``.
    """
    if not getattr(compiled_fn, "_boxed_call", False):
        compiled_fn = make_boxed_func(compiled_fn)

    # Note [Inputs needed in runtime epilogue after list clearing]
    # In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to
    # wrap the input arguments in a list, and clear the list from within the function.
    # Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`.
    #
    # This is needed for Compiled Autograd since some of the inputs (activations) should be freed early.
    # However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs
    # **after** the compiled function has finished running. There are two main cases:
    #   (1) Input mutations: If there are any input mutations that we must run outside of the graph, we need access to the input.
    #   (2) Output aliasing: Outputs that alias graph inputs generally must be regenerated outside of the `autograd.Function`,
    #       and doing so requires us accessing the corresponding input after the compiled artifact has run.
    epilogue_args_idx = []
    epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices)
    for info in runtime_metadata.output_info:
        if (
            info.output_type == OutputType.alias_of_input
            or info.output_type == OutputType.is_input
        ):
            assert isinstance(info.base_idx, int)
            epilogue_args_idx.append(info.base_idx)

    if config.unlift_effect_tokens:
        assert len(runtime_metadata.tokens) == 0

    # Handlers are built once here; they are only needed (and only defined)
    # when at least one output is aliased.
    if runtime_metadata.num_outputs_aliased > 0:
        output_handlers = tuple(
            make_output_handler(info, runtime_metadata, trace_joint)
            for info in runtime_metadata.output_info
        )

    def record_runtime_wrapper_prologue_enter() -> Optional[
        AbstractContextManager[None]
    ]:
        # Enters (and returns) a profiler record scope only when the profiler
        # is enabled and record_runtime_overhead is set; otherwise returns None.
        if (
            torch.autograd.profiler._is_profiler_enabled
            and dynamo_config.record_runtime_overhead
        ):
            cm = torch._C._profiler._RecordFunctionFast(
                "AOTDispatcher Runtime Wrapper Prologue"
            )
            cm.__enter__()
            return cm
        return None

    def record_runtime_wrapper_prologue_exit(
        cm: Optional[AbstractContextManager[None]],
    ) -> None:
        # Closes the scope opened by record_runtime_wrapper_prologue_enter, if any.
        if cm is not None:
            cm.__exit__(None, None, None)

    @simple_wraps(compiled_fn)
    def runtime_wrapper(args: list[Any]):
        # Create context manager for profiler
        cm = record_runtime_wrapper_prologue_enter()

        # stash a ref to each input tensor we plan to use after the compiled function
        orig_inputs = {i: args[i] for i in epilogue_args_idx}

        if keep_input_mutations:
            mutated_args = (
                args[i]
                for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd
            )
            torch.autograd.graph.increment_version(mutated_args)

        if trace_joint:
            args_ = list(args)
            # See Note [Detaching inputs that never need gradients]
            for idx in indices_of_inps_to_detach:
                if isinstance(args_[idx], torch.Tensor):
                    args_[idx] = args_[idx].detach()

            # It's possible to have trace_joint inside user specified with no_grad() region,
            # if there is a nested with enable_grad(), that forces some outputs to require gradients.
            # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
            with (
                torch.autograd._force_original_view_tracking(True),
                torch.enable_grad(),
            ):
                record_runtime_wrapper_prologue_exit(cm)
                all_outs = call_func_at_runtime_with_args(
                    compiled_fn, args_, disable_amp=disable_amp, steal_args=True
                )
        else:
            # When we have an inference graph, we run with grad disabled.
            # It's possible to get an inference graph with inputs that require grad,
            # in which case we want to make sure autograd is disabled
            # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on)
            # NOTE: We use _set_grad_enabled directly to reduce runtime overhead
            grad_enabled = torch.is_grad_enabled()
            try:
                if grad_enabled:
                    torch._C._set_grad_enabled(False)
                record_runtime_wrapper_prologue_exit(cm)
                all_outs = call_func_at_runtime_with_args(
                    compiled_fn, args, disable_amp=disable_amp, steal_args=True
                )
            finally:
                # Restore grad mode only if we were the ones who turned it off.
                if grad_enabled:
                    torch._C._set_grad_enabled(True)
        # Drop our reference to the argument list; inputs still needed by the
        # epilogue are held in orig_inputs.
        del args

        num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices
        num_intermediate_bases = runtime_metadata.num_intermediate_bases

        # Compiled outputs are laid out as:
        # [updated mutated inputs..., outputs..., intermediate bases...]
        assert (
            len(all_outs)
            == num_mutated_runtime_inps
            + runtime_metadata.num_outputs
            + num_intermediate_bases
        )

        # Step 3: After running the compiled fw, apply updates to mutated inputs
        num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
        if num_mutations_to_apply > 0:
            updated_inputs = all_outs[:num_mutations_to_apply]
            fw_outs = all_outs[num_mutations_to_apply:]

            for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices):
                meta = runtime_metadata.input_info[inpt_idx]
                if not meta.mutates_data and not meta.mutates_metadata:
                    continue
                original_inpt = orig_inputs[inpt_idx]
                updated_inpt = updated_inputs[i]
                if meta.mutates_storage_metadata:
                    # See Note [set_() Input Mutations in AOTAutograd]
                    # mutates_storage_metadata means our input saw a x.set_(y) call.
                    # What if x **also** saw a data and/or a metadata mutation?
                    # (1) If the [meta]data mutation occurred after the set_(),
                    #     then there is no need to copy_() the data.
                    #     When we perform x.set_(x_updated), we are guaranteed that
                    #     x_updated already has the final version of the data/metadata
                    # (2) If a data mutation occurred before the set_().
                    #     This case seems very difficult to support.
                    #     TODO: discuss on the PR and decide if we want to try to
                    #     either support it, or detect and ban it.
                    if trace_joint:
                        assert isinstance(updated_inpt, TensorAlias)
                        updated_inpt = updated_inpt.alias
                    with torch.no_grad():
                        original_inpt.set_(updated_inpt)
                    continue
                if meta.mutates_metadata and not meta.mutates_data:
                    if trace_joint:
                        assert isinstance(updated_inpt, TensorAlias)
                        updated_inpt = updated_inpt.alias
                    # We need to grab the size/stride/storage_offset from the compiled forward,
                    # and use that to mutate the metadata of the input
                    original_inpt.as_strided_(
                        updated_inpt.size(),
                        updated_inpt.stride(),
                        updated_inpt.storage_offset(),
                    )
                else:
                    if meta.mutates_data and meta.mutates_metadata:
                        original_inpt.as_strided_(
                            updated_inpt.size(),
                            updated_inpt.stride(),
                            updated_inpt.storage_offset(),
                        )
                    else:
                        assert meta.mutates_data
                    if meta.is_leaf and original_inpt.requires_grad:
                        # We can hit this situation in this case:
                        #   def f(x):
                        #       x.detach().mul_(2)
                        #       return x + 1
                        # AOTAutograd will see a mutation in the above case, and try to
                        # apply a copy_() here, in the epilogue.
                        # But if x required gradients, and is a leaf, then autograd
                        # will yell at us for trying to mutate it.
                        # However, it's only possible to end up in this scenario (like the above)
                        # if all of the mutations to the leaf input were non-autograd-tracking mutations
                        # (aka mutations under no_grad(), or on detached views).
                        # In that case, we fully want to hide the mutation from autograd, so detaching is ok.
                        original_inpt.detach().copy_(updated_inpt)
                    else:
                        original_inpt.copy_(updated_inpt)
        else:
            fw_outs = all_outs

        # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
        # compiling them.
        if runtime_metadata.num_outputs_aliased > 0:
            # The compiled forward also returned intermediate bases. We don't want to return them to the user.
            expect_num_outputs = (
                len(output_handlers) + runtime_metadata.num_intermediate_bases
            )
            assert len(fw_outs) == expect_num_outputs
            # builtins.zip (not the module-level strict zip) is used on purpose:
            # fw_outs is longer than output_handlers by the intermediate bases,
            # and truncating to the handlers drops those bases.
            ret_outs = [
                handler(orig_inputs, fw_outs, out)
                for out, handler in builtins.zip(fw_outs, output_handlers)
            ]
        else:
            ret_outs = fw_outs

        if runtime_metadata.dynamic_outputs:
            for t, o in zip(ret_outs, runtime_metadata.output_info):
                if o.dynamic_dims is None:
                    continue
                maybe_mark_dynamic_helper(t, o.dynamic_dims)
        # Apply the recorded grad-mode mutation, if the metadata carries one.
        if runtime_metadata.grad_enabled_mutation is not None:
            torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation)
        return ret_outs

    if not (trace_joint and _should_disable_saved_tensors_hooks()):
        return runtime_wrapper

    # Disabling saved tensors hooks
    @simple_wraps(runtime_wrapper)
    def _runtime_wrapper(*args, **kwargs):
        with _disable_saved_tensors_hooks():
            return runtime_wrapper(*args, **kwargs)

    return _runtime_wrapper


# WARNING: this does NOT operate on TraceFn
@dataclass
class FunctionalizedRngRuntimeWrapper(InductorWrapper):
    """Wrapper that threads the CUDA RNG seed/offset through the compiled function.

    At compile time it appends a (seed, offset) pair to the example inputs; at
    runtime it appends the current pair to the arguments and writes the new
    offset returned by the compiled function back via ``CUDARngStateHelper``.
    """

    # TODO: I would love to get rid of this argument, but it's
    # Wrapped pretty tightly around our aot_dispatch_autograd logic.
    # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices
    # for setting placeholder strides(which is done before runtime, before this wrapper runs)
    # and for saving tensors for backward (which is done during runtime, after this wrapper runs)
    # So in aot_dispatch_autograd, this wrapper can't edit the set of outs without making one
    # of those two indices incorrect.
    return_new_outs: bool = True

    def pre_compile(
        self,
        flat_fn: torch.fx.GraphModule,
        flat_args,
        aot_config,
        *,
        fw_metadata,
    ) -> None:
        """Append the fake-mode RNG (seed, offset) to ``flat_args`` in place.

        Does nothing unless ``config.functionalize_rng_ops`` is set.
        """
        if config.functionalize_rng_ops:
            # Update example inputs for the fw_compiler
            fake_mode = detect_fake_mode()
            assert fake_mode is not None
            seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
            flat_args.extend([seed, offset])
            # We are not clearing flat_args here because
            # 1) There is a check in the debug compiler at the end
            # 2) It does not matter as these are fake tensors

    def post_compile(
        self,
        compiled_fn,
        aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Return a wrapper that supplies RNG state to ``compiled_fn`` when needed.

        If ``runtime_metadata.is_rng_op_functionalized`` is false at call time,
        the wrapper simply forwards to ``compiled_fn``.
        """

        @wraps(compiled_fn)
        def wrapper(runtime_args: list[Any]):
            if runtime_metadata.is_rng_op_functionalized:
                # Add the seed and offset to args
                seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
                runtime_args.extend([seed, offset])
                out = compiled_fn(runtime_args)
                out = self._functionalized_rng_runtime_epilogue(
                    runtime_metadata,
                    out,
                    # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper
                    runtime_metadata.num_forward_returns,
                )
                return out
            return compiled_fn(runtime_args)

        return wrapper

    # Calling convention: If we are running functionalized RNG, then outs consists
    # of (user_outs, rng_offset)
    def _functionalized_rng_runtime_epilogue(
        self,
        metadata: ViewAndMutationMeta,
        outs,
        offset_index,
    ):
        """Store the new RNG offset found at ``outs[offset_index]``.

        Returns ``outs`` with that entry removed when ``return_new_outs`` is
        true, otherwise ``outs`` unchanged. If RNG ops are not functionalized,
        ``outs`` is returned untouched and no state is written.
        """
        if metadata.is_rng_op_functionalized:
            assert metadata.num_outputs_rng_offset == 1
            new_rng_offset = outs[offset_index]
            CUDARngStateHelper.set_new_offset(new_rng_offset)
            if self.return_new_outs:
                user_outs = outs[:offset_index] + outs[offset_index + 1 :]
                return user_outs
            else:
                return outs

        return outs


# WARNING: this does NOT operate on TraceFn
@dataclass
class FakifiedOutWrapper(InductorWrapper):
    """Wrapper that returns recorded output metadata on the first call.

    When the tracing context requests ``fakify_first_call``, the forward
    graph's output ``meta["val"]`` entries are captured at compile time,
    re-strided to the strides supplied through ``set_fwd_output_strides``,
    and returned instead of running ``compiled_fn`` on the first invocation.
    """

    out_metas: list[torch.Tensor] = field(default_factory=list)
    # TracingContext.fwd_output_strides
    # Generated from actually doing compile
    # NB: an entry is None if it's not a Tensor
    fwd_output_strides: Optional[list[Optional[list[int]]]] = None
    needs_post_compile: bool = True

    def pre_compile(
        self,
        fw_module: fx.GraphModule,  # Must be fw_module from aot_dispatch_*_graph
        flat_args,
        aot_config,
        *,
        fw_metadata,
    ) -> None:
        """Capture output metadata, or mark this wrapper as a no-op.

        The last node of ``fw_module.graph`` is taken to be the output node and
        the ``meta["val"]`` of each of its returned nodes is recorded.
        """
        tracing_context = torch._guards.TracingContext.try_get()
        if tracing_context and tracing_context.fakify_first_call:
            self.out_metas = [
                n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])
            ]
        else:
            self.needs_post_compile = False

    def _compute_output_meta_with_inductor_strides(self):
        """Return ``out_metas`` with tensors re-strided to ``fwd_output_strides``.

        ``self.out_metas`` is modified in place and returned. Entries that are
        not tensors, have no recorded strides, or already match are left alone.
        """
        out = self.out_metas
        fwd_output_strides = self.fwd_output_strides
        if not fwd_output_strides:
            return out

        from torch.fx.experimental.symbolic_shapes import statically_known_true

        for i in range(len(out)):
            if not isinstance(out[i], Tensor):
                continue
            strides = fwd_output_strides[i]
            # fwd_output_strides is best effort by Inductor.  When an output
            # Tensor has unbacked SymInts, Inductor may sometimes be unable
            # to compute what the output stride would be.  If Inductor doesn't
            # have any clear direction on the layout, we don't have to run
            # as_strided.  To repro without this, run:
            #
            # python test/distributed/test_dynamo_distributed.py
            # TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding
            if strides is None:
                continue
            if all(
                statically_known_true(s1 == s2)
                for s1, s2 in zip(out[i].stride(), strides)
            ):
                continue
            out[i] = out[i].as_strided(out[i].shape, strides)
        return out

    # To be called post compile
    def set_fwd_output_strides(self, fwd_output_strides):
        """Record the forward output strides used by ``post_compile``."""
        self.fwd_output_strides = fwd_output_strides

    def post_compile(
        self,
        compiled_fn,
        aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Return a wrapper that yields the fakified outputs once, then defers.

        Requires ``set_fwd_output_strides`` to have been called when
        ``needs_post_compile`` is true (asserted). Otherwise ``compiled_fn`` is
        returned unchanged.
        """
        if self.needs_post_compile:
            assert self.fwd_output_strides is not None
            fakified_out = self._compute_output_meta_with_inductor_strides()

            @wraps(compiled_fn)
            def wrapper(runtime_args):
                nonlocal fakified_out
                # First call: hand back the recorded metadata and forget it so
                # every later call runs the real compiled function.
                if fakified_out is not None:
                    out = fakified_out
                    fakified_out = None
                    return out
                return compiled_fn(runtime_args)

            return wrapper
        # If we don't need to fakify, we can just return the original compiled function
        return compiled_fn


# This wrapper handles the AOTDispatch runtime logic for tensor subclasses.
# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor,
# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs).
# This function handles the wrapping and unwrapping of tensor subclasses at runtime.
@dataclass
class AOTDispatchSubclassWrapper(CompilerWrapper):
    """Unwrap tensor-subclass inputs before, and re-wrap outputs after, the compiled call."""

    trace_joint: bool
    fw_only: Optional[Callable]  # Not cached, only used in pre_compile
    maybe_subclass_meta: Optional[SubclassMeta]
    num_fw_outs_saved_for_bw: Optional[int]

    def pre_compile(
        self,
        flat_fn: TraceFn,
        flat_args: list[FxValue],
        flat_args_descs: list[AOTInput],
        aot_config: AOTConfig,
        *,
        fw_metadata: ViewAndMutationMeta,
    ):
        """Run ``aot_dispatch_subclass`` and remember the subclass metadata it returns.

        Returns the rewritten function, args and arg descriptors together with
        the unmodified ``fw_metadata``.
        """
        dispatched = aot_dispatch_subclass(
            flat_fn,
            flat_args,
            flat_args_descs,
            is_joint_structure=self.trace_joint,
            meta=fw_metadata,
            fw_only=self.fw_only,  # type: ignore[arg-type]
        )
        new_flat_fn, new_flat_args, new_flat_args_descs, subclass_meta = dispatched
        self.maybe_subclass_meta = subclass_meta
        return new_flat_fn, new_flat_args, new_flat_args_descs, fw_metadata

    def post_compile(
        self,
        compiled_fn,
        _aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Return a boxed function that handles subclass (un)wrapping around ``compiled_fn``.

        When no subclass metadata was recorded, ``compiled_fn`` is returned as is.
        """
        if self.maybe_subclass_meta is None:
            return compiled_fn

        # Output metadata is captured once; input metadata is read per call.
        out_subclass_metas = runtime_metadata.subclass_fw_graph_out_meta

        @wraps(compiled_fn)
        def inner_fn(args: list[Any]):
            dense_args = runtime_unwrap_tensor_subclasses(
                args,
                subclass_metas=runtime_metadata.subclass_inp_meta,
                append_symints=True,
            )
            # The caller's list is emptied so it no longer holds the inputs.
            args.clear()
            # compiled_fn is expected to follow the boxed calling convention.
            dense_outs = compiled_fn(dense_args)
            return wrap_tensor_subclasses(
                dense_outs,
                subclass_metas=out_subclass_metas,
                num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw,
                is_runtime=True,
                included_subclass_symints=True,
            )

        # Mark the wrapper as boxed.
        inner_fn._boxed_call = True  # type: ignore[attr-defined]
        return inner_fn


@dataclass
class EffectTokensWrapper(CompilerWrapper):
    """Prepend placeholder effect tokens to the inputs and strip them from the outputs."""

    def post_compile(
        self,
        compiled_fn,
        _aot_config,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Return a boxed function that hides effect tokens from the caller.

        The number of tokens is ``len(runtime_metadata.tokens)``, fixed at
        wrap time.
        """
        num_tokens = len(runtime_metadata.tokens)

        @wraps(compiled_fn)
        def inner_fn(args: list[Any]):
            call_args = args
            if num_tokens > 0:
                # Forward effect tokens are passed as None placeholders
                # (See Note [Side-Effectful Tokens in AOTAutograd]).
                call_args = [None] * num_tokens
                call_args.extend(args)
                args.clear()

            outs = compiled_fn(call_args)

            # Inductor cache DummyModule can return None
            if outs is None:
                return None
            if num_tokens == 0:
                return outs
            # Drop the leading effect tokens
            # (See Note [Side-Effectful Tokens in AOTAutograd]).
            return outs[num_tokens:]

        # Mark the wrapper as boxed.
        inner_fn._boxed_call = True  # type: ignore[attr-defined]
        return inner_fn


# MOTIVATION:
#
# When tracing functions for future execution, one must be careful not to pass
# in the same input tensor multiple times (e.g., f(x, x)), as this can result
# in graphs that are ONLY valid if you later pass a new tensor in exactly the
# same way (e.g., f(y, y)).  (NB: we really mean duplicate; two distinct
# tensors that alias each other is a different situation that is covered by
# aot_dispatch_deduplicated_autograd). Here are two examples:
#
# (1) Suppose you have a function:
#
#   def f(x, y):
#       return x + y
#
# If you make_fx(f)(x, x), you will trace out:
#
#   def f(x, y):
#       return y + y
#
# Oops!
#
# (2) For most tensors x and y, you can compute f's gradient with respect to
# these two inputs by saying torch.autograd.grad(f(x, y), (x, y)).  However,
# if x is y, you will trace out a program that gets incorrect gradients:
#
#   >>> x = torch.randn(1, requires_grad=True)
#   >>> torch.autograd.grad(x + x, (x, x))
#   (tensor([2.]), tensor([2.]))
#
# In other words, the gradient is double-counted.  Deduplicating the arguments
# gives you an appropriate gradient:
#
#   >>> y = torch.randn(1, requires_grad=True)
#   >>> torch.autograd.grad(x + y, (x, y))
#   (tensor([1.]), tensor([1.]))
#
# HOW TO DEDUPLICATE:
#
# There are a few strategies, in order of preference:
#
# 1. For every duplicate argument to the function, detach it into
#    a separate leaf tensor, so that it is no longer duplicated.
+# +# PRO: The resulting compiled graph works for any configuration +# of duplicated arguments. +# +# CON: It does not (naively) work if you mutate the metadata of inputs: +# +# def f(x, y): +# x.transpose_(0, 1) +# y.transpose_(0, 2) +# +# x = torch.randn(2, 3, 4) +# f(x, x) +# +# The ordering of the transposes inside f dictates whether or not +# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute +# what metadata mutations should get applied to each input; you need to +# assume they aren't duplicates (what we do today) or preserve +# the original metadata mutations exactly in order, so that they work +# for any duplicate configuration. +# +# CON: It does not (naively) work if you mutate the data of inputs. +# In particular, leaf tensors that require grad cannot be mutated, +# this makes it impossible to differentiate with respect to the original +# base. +# +# 2. For every duplicate argument to the function, remove it, so it is +# no longer part of the "true" signature: +# +# PRO: Implemented naively, it still works for metadata/data mutation. +# +# CON: The resulting compiled graph is duplicate-specialized: it only +# works if future calls duplicate arguments in exactly the same way. +# Horribly, Dynamo doesn't guard on this at the moment. But even if +# it did, you could still end up recompiling a bunch of each duplicate. +# +# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if +# Dynamo's guards are not enough. In practice, this seems to cover +# everything. 
+# +@dataclass +class AOTDedupeWrapper(CompilerWrapper): + keep_arg_mask: list[bool] = field(default_factory=list) + add_dupe_map: list[int] = field(default_factory=list) + old_input_metadata: list[InputAliasInfo] = field(default_factory=list) + needs_post_compile: bool = True + + # NB: Hot path, avoid set lookups here + # TODO: Can avoid the zip here too, probably + def remove_dupe_args(self, args): + return [t for t, keep in zip(args, self.keep_arg_mask) if keep] + + def add_dupe_args(self, args): + return [args[i] for i in self.add_dupe_map] + + def pre_compile( + self, + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[TraceFn, list[FxValue], list[AOTInput], ViewAndMutationMeta]: + # Use information about whether or not flat_fn mutates its arguments + # or not to handle dupe args + + # Strategy 1: For any input that is not mutated, we can leafify it if we + # need to remove a duplicate. + leaf_flat_args: list[FxValue] = [] + leaf_flat_args_descs: list[AOTInput] = [] + args_set = set() + ok = True + + for i, (a, a_desc) in enumerate(zip(flat_args, flat_args_descs)): + if not isinstance(a, torch.Tensor): + leaf_flat_args.append(a) + leaf_flat_args_descs.append(a_desc) + elif a not in args_set: + args_set.add(a) + leaf_flat_args.append(a) + leaf_flat_args_descs.append(a_desc) + elif ( + not fw_metadata.input_info[i].mutates_data + and not fw_metadata.input_info[i].mutates_metadata + ): + leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) + leaf_flat_args_descs.append(a_desc) + else: + ok = False + break + + if ok: + self.needs_post_compile = False + return flat_fn, leaf_flat_args, leaf_flat_args_descs, fw_metadata + + if requires_subclass_dispatch(leaf_flat_args, fw_metadata): + raise RuntimeError( + """\ + Encountered duplicate inputs that are mutated in the graph, but at least one input/output + to the graph is a tensor subclass. 
This is not supported today. You can try to + remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + # export path: ban duplicate inputs for now, add later if requested. + if aot_config.is_export: + raise RuntimeError( + f"""\ + Encountered duplicated inputs that are mutated in the graph you are trying to export. + This functionality is currently not supported. If needed, please file a github issue. + + fw_metadata={str(fw_metadata)} + """ + ) + + # Strategy 2: Duplicate specialization + # + # When we have duplicate arguments in a function call, we need to handle them specially. + # For example, if we have a function call f(a, b, a, c), we need to: + # + # 1. Remove duplicates to get a deduplicated list [a, b, c] + # 2. Compile our function to work with this deduplicated list + # 3. At runtime, convert incoming arguments with duplicates to the deduplicated form + # 4. Pass the deduplicated arguments to our compiled function + # + # To do this, we need two helper functions: + # + # - remove_dupe_args: Converts [a, b, a, c] -> [a, b, c] + # - add_dupe_args: Converts [a, b, c] -> [a, b, a, c] + # + # For our example [a, b, a, c], we track: + # + # - seen_args = {a: 0, b: 1, c: 2} (maps each unique arg to its first position) + # - add_dupe_map = [0, 1, 0, 2] (tells us how to reconstruct the original list) + # - keep_arg_mask = [True, True, False, True] (tells us which args to keep when deduplicating) + + seen_args: dict[Tensor, int] = {} + # Implicitly map duped arg position (list index) to de-duped arg position + keep_arg_mask: list[bool] = [] + add_dupe_map: list[int] = [] + duped_arg_len = len(flat_args) + + j = 0 # index into deduped_flat_args + for t in flat_args: + if isinstance(t, torch.Tensor): + if t in seen_args: + keep_arg_mask.append(False) + add_dupe_map.append(seen_args[t]) + continue + seen_args[t] = j + + keep_arg_mask.append(True) + add_dupe_map.append(j) + j += 1 + assert len(add_dupe_map) == duped_arg_len, ( + 
f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" + ) + + self.keep_arg_mask = keep_arg_mask + self.add_dupe_map = add_dupe_map + + deduped_flat_args = self.remove_dupe_args(flat_args) + # TODO: instead of arbitrarily removing args, it might be useful to + # have a record that these were duped, perhaps as a mutable attribute + # on the kept arg? Do this if someone needs it + deduped_flat_args_descs = self.remove_dupe_args(flat_args_descs) + + # Update our input metadata to remove duped input metadata. + updated_fw_metadata = remove_dupe_metadata( + fw_metadata, keep_arg_mask, add_dupe_map + ) + + if ( + tracing_context := TracingContext.try_get() + and aot_config.aot_autograd_arg_pos_to_source + ): + # TODO(voz): This structure is 1:1, we could consider an alternate structure like + # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there, + # which feels like needless complexity for a tiny bit of efficiency at this point. 
            # (Continuation of a method whose beginning lies before this span.)
            # NOTE(review): the enclosing `if` binds `tracing_context` with an
            # unparenthesized walrus, so the name receives the value of the whole
            # `... and ...` expression rather than `TracingContext.try_get()`.
            # Confirm that the `.guards_context` access below sees the intended object.
            for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
                zip(add_dupe_map, keep_arg_mask)
            ):
                if not keep_arg:
                    dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
                        dupe_arg_pos
                    ]
                    kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[
                        kept_pos
                    ]
                    tracing_context.guards_context.aotautograd_guards.append(  # type: ignore[attr-defined]
                        DuplicateInputs(kept_arg_source, dupe_arg_source)
                    )

        # The traced function receives deduplicated args; re-insert the duplicates
        # before handing them to the user's flat_fn.
        @simple_wraps(flat_fn)
        def wrapped_flat_fn(
            *args: FxValue,
        ) -> tuple[list[FxValue], list[AOTOutput]]:
            outs, out_descs = call_and_expect_output_descs(
                flat_fn, self.add_dupe_args(args)
            )
            return outs, out_descs

        if config.debug_assert:
            # Recompute the metadata from the wrapped function and check that it
            # matches the metadata derived by removing the duplicated entries.
            ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                without_output_descs(wrapped_flat_fn),
                flat_args_descs=deduped_flat_args_descs,
                static_input_indices=aot_config.static_input_indices,
                keep_input_mutations=fw_metadata.keep_input_mutations,
                is_train=fw_metadata.is_train,
            )(*deduped_flat_args)
            assert ref_fw_metadata == updated_fw_metadata, (
                f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
            )

        return (
            wrapped_flat_fn,
            deduped_flat_args,
            deduped_flat_args_descs,
            updated_fw_metadata,
        )

    def post_compile(
        self,
        compiled_fn,
        aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Wrap ``compiled_fn`` so it can be called with the original (duplicated) args.

        Returns ``compiled_fn`` untouched when ``self.needs_post_compile`` is false.
        Otherwise the returned callable takes a single list of arguments, asserts
        that ``add_dupe_args(remove_dupe_args(args))`` yields the very same objects,
        then strips the duplicates, clears the caller's list and forwards the
        deduplicated list to ``compiled_fn``. ``runtime_metadata`` is not read here.
        """
        if not self.needs_post_compile:
            return compiled_fn

        @wraps(compiled_fn)
        def wrapped_compiled_fn(args: list[Any]):
            deduped_args = self.remove_dupe_args(args)
            # Empty the caller's list in place before invoking the compiled function.
            args.clear()
            return compiled_fn(deduped_args)

        wrapped_compiled_fn._boxed_call = True  # type: ignore[attr-defined]

        # This can be uncommented when we properly guard for duplicates,
        # but right now we must not do it.
        # if not config.debug_assert:
        #     return wrapped_compiled_fn

        @wraps(wrapped_compiled_fn)
        def debugged_compiled_fn(args):
            # Test that the computed remove/add arg functions are an inverse
            new_args = self.add_dupe_args(self.remove_dupe_args(args))
            seen: dict[Any, None] = {}
            for i, (x, y) in enumerate(zip(new_args, args)):
                seen[y] = None
                assert x is y, format_guard_bug_msg(
                    aot_config,
                    f"{describe_input(i, aot_config)} would be a duplicate of "
                    f"{describe_input(self.add_dupe_map[i], aot_config)}",
                )
            # This is only an error if there is metadata mutation on both of
            # the duped arguments; in this case, we need to know what order
            # the metadata mutation applies in.  You'll get the correct result
            # otherwise, because a graph that assumes distinct inputs works if
            # you dupe the inputs (the gradient contributions from each input
            # will get summed up appropriately.)
            #
            # TODO: work out how to setup this assert correctly
            """
            assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
                f"there would be {unique_args} distinct arguments"
            )
            """
            return wrapped_compiled_fn(args)

        debugged_compiled_fn._boxed_call = True  # type: ignore[attr-defined]

        return debugged_compiled_fn


# This layer handles the situation where you have two inputs that alias each other,
# and one of the inputs is mutated.
# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
#
# pre-condition: AOTDedupWrapper has already run.
# (This function will in theory work if there are duplicate args.
# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
# would cause us to hit that path more frequently).
@dataclass
class AOTSyntheticBaseWrapper(CompilerWrapper):
    # Currently, the only reason we need to plumb this bool is because
    # the synthetic base code prohibits more cases in the autograd case than the inference case.
    trace_joint: bool  # TODO: refactor trace_joint
    # Switched off by pre_compile when no synthetic base turned out to be necessary.
    needs_post_compile: bool = True
    # Filled in by pre_compile from create_synthetic_base_metadata; read again at runtime.
    aliased_arg_idx_with_metadata_mutations: list[int] = field(default_factory=list)

    def pre_compile(
        self,
        flat_fn: TraceFn,
        flat_args: list[FxValue],
        flat_args_descs: list[AOTInput],
        aot_config: AOTConfig,
        *,
        fw_metadata: ViewAndMutationMeta,
    ) -> tuple[Callable, list[FxValue], list[AOTInput], ViewAndMutationMeta]:
        """Replace mutated, aliased inputs with synthetic bases before tracing.

        The grouping itself is delegated to ``merge_view_inputs``. When it reports
        that no synthetic base is needed, ``needs_post_compile`` is set to False and
        the function, arguments, descriptors and metadata are returned unchanged.
        Otherwise this returns a wrapper around ``flat_fn`` that regenerates the
        original arguments from the bases, together with the new argument list, its
        descriptors and the metadata produced by ``create_synthetic_base_metadata``.

        Raises:
            RuntimeError: if synthetic bases are required and either
                ``requires_subclass_dispatch(flat_args, fw_metadata)`` is true or
                ``aot_config.is_export`` is set.
        """
        is_inference = not self.trace_joint
        (
            flat_args_with_synthetic_bases,
            flat_args_descs_with_synthetic_bases,
            synthetic_base_info,
        ) = merge_view_inputs(
            aot_config,
            flat_args,
            flat_args_descs,
            fw_metadata.input_info,
            is_inference=is_inference,
        )

        # Happy path: we don't need synthetic bases
        if synthetic_base_info is None:
            self.needs_post_compile = False
            return flat_fn, flat_args, flat_args_descs, fw_metadata

        # export path: ban synthetic bases for now, add later if requested.
        if requires_subclass_dispatch(flat_args, fw_metadata):
            raise RuntimeError(
                """\
Encountered aliased inputs that are mutated in the graph, but at least one input/output
to the graph is a tensor subclass. This is not supported today. You can try to
remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
            )

        if aot_config.is_export:
            raise RuntimeError(
                f"""\
Encountered aliased inputs that are mutated in the graph you are trying to export.
This functionality is currently not supported. If needed, please file a github issue.

synthetic_base_info={str(synthetic_base_info)}

fw_metadata={str(fw_metadata)}
        """
            )

        assert len(fw_metadata.input_info) == len(synthetic_base_info)

        # Update our forward metadata to take synthetic bases into account
        (
            fw_metadata_updated,
            aliased_arg_idx_with_metadata_mutations,
        ) = create_synthetic_base_metadata(
            fw_metadata,
            synthetic_base_info,
            flat_args,
            flat_args_with_synthetic_bases,
            flat_args_descs_with_synthetic_bases,
        )
        # Save old input args for post-compile
        self.old_input_info = fw_metadata.input_info

        self.aliased_arg_idx_with_metadata_mutations = (
            aliased_arg_idx_with_metadata_mutations
        )
        replay_views = config.view_replay_for_aliased_outputs

        # Rebuild the original argument list from the synthetic-base calling
        # convention: an int entry selects a primal directly, a (base_idx, view)
        # entry regenerates the view from the selected base.
        def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
            f_args_inner = []
            # pyrefly: ignore [not-iterable]
            for inner_idx_or_tuple in synthetic_base_info:
                if isinstance(inner_idx_or_tuple, int):
                    f_args_inner.append(primals[inner_idx_or_tuple])
                else:
                    inner_base_idx, view_tensor = inner_idx_or_tuple
                    base = primals[inner_base_idx]
                    view_arg = gen_alias_from_base(
                        base,
                        view_tensor,
                        view_tensor.requires_grad,
                        replay_views=replay_views,
                    )
                    f_args_inner.append(view_arg)
            return f_args_inner

        @simple_wraps(flat_fn)
        def wrapped_flat_fn(*args):
            unpacked_args = _unpack_synthetic_bases(args)
            # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)
            # is to relieve the downstream logic from having to reason about mutations on inputs that alias
            # each other, by replacing aliased inputs with a synthetic base.
            # One area where this breaks down a bit however is if one of those aliased inputs
            # experienced a metadata mutation.
            # We are now obligated to reapply the metadata mutation directly to the user's input;
            # it isn't enough to apply mutations back to the synthetic base in the downstream logic.
            #
            # The way we handle this is by pretending that those aliased inputs that experience metadata mutations
            # are additional outputs in the user's forward function.
            # The downstream logic will just treat these as "user outputs that alias inputs".
            # However, we will manually grab them at runtime here, use them to reapply the metadata mutation
            # to the user inputs, and not return them to the user.
            aliased_args_with_metadata_mutations = [
                x
                for i, x in enumerate(unpacked_args)
                if i in self.aliased_arg_idx_with_metadata_mutations
            ]
            out, out_descs = call_and_expect_output_descs(flat_fn, unpacked_args)
            if len(aliased_args_with_metadata_mutations) > 0:
                # TODO: record more detailed desc information here
                return (*out, *aliased_args_with_metadata_mutations), (
                    *out_descs,
                    *(
                        [
                            MetadataMutationAOTOutput(i)
                            for i in range(
                                len(self.aliased_arg_idx_with_metadata_mutations)
                            )
                        ]
                    ),
                )
            else:
                return out, out_descs

        if config.debug_assert:
            # Recompute the metadata through the wrapper and compare it with the
            # metadata derived by create_synthetic_base_metadata.
            ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                without_output_descs(wrapped_flat_fn),
                flat_args_descs=flat_args_descs_with_synthetic_bases,
                static_input_indices=aot_config.static_input_indices,
                keep_input_mutations=fw_metadata.keep_input_mutations,
                is_train=fw_metadata.is_train,
            )(*flat_args_with_synthetic_bases)
            assert ref_fw_metadata == fw_metadata_updated, (
                f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
                f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
            )
        return (
            wrapped_flat_fn,
            flat_args_with_synthetic_bases,
            flat_args_descs_with_synthetic_bases,
            fw_metadata_updated,
        )

    def post_compile(
        self,
        compiled_fn,
        aot_config: AOTConfig,
        *,
        runtime_metadata: ViewAndMutationMeta,
    ):
        """Adapt ``compiled_fn`` so it accepts the caller's original arguments.

        Returns ``compiled_fn`` untouched when ``self.needs_post_compile`` is false.
        Otherwise, on every call the wrapper recomputes the synthetic bases with
        ``merge_view_inputs``, clears the caller's argument list and invokes
        ``compiled_fn`` on the synthetic-base argument list. If any aliased
        argument had a metadata mutation, that many trailing outputs are split
        off, their size/stride/storage offset are applied to the corresponding
        inputs with ``as_strided_``, and only the remaining outputs are returned.
        ``runtime_metadata`` is not read here.
        """
        if not self.needs_post_compile:
            return compiled_fn

        is_inference = not self.trace_joint

        @wraps(compiled_fn)
        def wrapped_compiled_fn(args):
            # TODO: this sure seems expensive to run at runtime (which
            # post_compile seems to imply it does?!)
            args_with_synthetic_bases, _, synthetic_base_info = merge_view_inputs(
                aot_config, args, None, self.old_input_info, is_inference=is_inference
            )
            assert synthetic_base_info is not None
            # Grab the user inputs before `args` is cleared below.
            aliased_args_w_metadata_mutations = [
                args[i] for i in self.aliased_arg_idx_with_metadata_mutations
            ]
            num_aliased_args_with_metadata_mutations = len(
                aliased_args_w_metadata_mutations
            )
            args.clear()
            outs = compiled_fn(args_with_synthetic_bases)
            if num_aliased_args_with_metadata_mutations > 0:
                # This code does not handle **all** input metadata mutations.
                # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
                # (which only happens if at least one aliased input experienced a data mutation).
                # e.g:
                # def f(a, b):
                #     a.mul_(2)
                #     b.t_(1, 0)
                # f(x.view(2, 2), x.view(2, 2))
                mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
                user_outs = outs[:-num_aliased_args_with_metadata_mutations]
                for inp, mutated_inp in zip(
                    aliased_args_w_metadata_mutations, mutated_metadata_inps
                ):
                    inp.as_strided_(
                        mutated_inp.size(),
                        mutated_inp.stride(),
                        mutated_inp.storage_offset(),
                    )
                return user_outs
            return outs

        return wrapped_compiled_fn


# Note [Handling mutations on an input that aliases other inputs]
# The easiest example to show-case this edge case is here:
#
# def f(a, b):
#     a.mul_(2)
#     out = a + b
#     return out
# b = torch.ones(...)
# a = b.view(-1)
# f(a, b)
#
# In this situation, if a and b happened to be aliased, we need to trace something different!
# Suppose we had b = a.view(-1)
# (In this case, that means that `a._base is b`)
#
# We need to ensure that the aliasing relationship between a and b is preserved.
# We do that by detecting the specific situation above (mutate an input that aliases another input),
# and when we do that, we create a synthetic base argument. Then inside of the traced forward,
# we regenerate a and b off of that base.
# The complete example of the transformed function looks like this:
#
# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
# def traced_forward(base):
#     a = base.as_strided(...)
#     b = base.as_strided(...)
#     a_updated = a.mul(2)
#     base_updated = torch.as_strided_scatter(base, a_updated, ...)
#     b_updated = base_updated.as_strided(...)
#     out = a_updated + b_updated
#     return a_updated, out
#
# def compiled_fn(a, b):
#     // we detect that a is the "differentiable base" here
#     base = a
#     // In other situations, we might do either:
#     // (1) a and b are both views off of some larger differentiable base
#     //     assert a._base is b._base and a._base is not None
#     //     base = a._base
#     // (2) a and b both don't require gradients. Create a base from the storage
#     //     assert a._base is None and b._base is None
#     //     base = torch.Tensor(a.storage())
#     a_updated, out = traced_forward(base)
#     a.copy_(a_updated)
#     return out
#
# This function:
# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
# (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
#     to respect the new calling convention.
#
# The calling convention is as follows.
# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
# Where the ordering of the bases is determined from the ordering of the original view args.
# baseA will come before baseB if the earliest original argument coming from baseA
# showed up earlier in the argument list than the earliest original argument coming from baseB.
#
# Example, given some tensors a, b, c, d
# call site:
#   f(a, c.view(-1), b.view(-1), b, c, d)
# Modified argument list:
#   c_base comes first because the first c view came earlier in arg list than the first b view
#   a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases
#   b_base = torch.Tensor(b.storage())
#   c_base = torch.Tensor(c.storage())
#   f(c_base, b_base, a, d)
def merge_view_inputs(
    aot_config: AOTConfig,
    fwd_inputs: list[Any],
    # This is None when called at runtime from post_compile closure
    fwd_inputs_descs: Optional[list[AOTInput]],
    mutated_input_info: list[InputAliasInfo],
    *,
    # The autograd case currently has more restrictions than the inference case.
    is_inference: bool,
) -> tuple[
    list[Any], list[AOTInput], Optional[list[Union[int, tuple[int, torch.Tensor]]]]
]:
    """Group mutated, storage-sharing tensor inputs behind synthetic base arguments.

    Args:
        aot_config: forwarded to ``compute_overlapping_inputs``.
        fwd_inputs: the original flat inputs.
        fwd_inputs_descs: descriptors parallel to ``fwd_inputs``; when ``None``,
            ``DummyAOTInput(i)`` placeholders are created.
        mutated_input_info: per-input info, same length as ``fwd_inputs``; only
            ``mutates_data`` is read here.
        is_inference: when false, each consecutive pair of aliased inputs is
            additionally asserted to be differentiable views of one another.

    Returns:
        ``(args, descs, meta)``. ``meta`` is ``None`` (and the inputs and
        descriptors are returned as given) when no input has a data mutation or
        no group of aliased inputs ends up needing a synthetic base. Otherwise
        ``args`` is the synthetic bases followed by the remaining inputs, and
        ``meta[i]`` describes original argument ``i``: either its new index
        (``int``) or a ``(base_idx, view_tensor)`` pair.

    Raises:
        AssertionError: on length mismatch, or for the aliasing patterns the
            assertions below reject (different dtypes, differing bases, ...).
    """
    if fwd_inputs_descs is None:
        fwd_inputs_descs = [DummyAOTInput(i) for i in range(len(fwd_inputs))]

    # True when the two tensors are the same object, share a `_base`, or one is
    # the `_base` of the other.
    def _are_differentiable_views(view1, view2):
        if view1 is view2:
            return True
        if view1._base is None and view2._base is None:
            return False
        if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
            return True
        return False

    # True when both tensors, and their `_base` tensors if any, share one dtype
    # pairwise as checked below.
    def _same_dtype_views(view1, view2):
        if view1.dtype != view2.dtype:
            return False
        if view1._base is not None and view1.dtype != view1._base.dtype:
            return False
        if view2._base is not None and view2.dtype != view2._base.dtype:
            return False
        return True

    assert len(fwd_inputs) == len(mutated_input_info)
    if not [info for info in mutated_input_info if info.mutates_data]:
        # Return early when there are no mutations.
        return fwd_inputs, fwd_inputs_descs, None

    # Bucket tensor inputs by storage; non-tensor inputs go straight to other_args.
    storage_ref_to_idx: dict[StorageWeakRef, list[int]] = collections.defaultdict(list)
    base_args = []
    other_args = []
    base_args_descs = []
    other_args_descs = []
    for i, (inpt, source) in enumerate(zip(fwd_inputs, fwd_inputs_descs)):
        if isinstance(inpt, Tensor):
            storage_ref = StorageWeakRef(inpt.untyped_storage())
            storage_ref_to_idx[storage_ref].append(i)
        else:
            other_args.append(inpt)
            other_args_descs.append(source)
    # Note [Synthetic Base Info Metadata]
    # This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
    # It's either:
    # - another int (corresponding to the index in the argument list of the element from the outer calling convention)
    # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
    #   idx corresponds to which synthetic base from the outer calling context to view
    inner_calling_convention_meta: dict[int, Union[int, tuple[int, torch.Tensor]]] = {}
    for aliased_input_indices in storage_ref_to_idx.values():
        if len(aliased_input_indices) <= 1 or not any(
            # We only care about mutations that affect all aliases,
            # so metadata mutations on an input doesn't require us to do synthetic base handling.
            mutated_input_info[inpt_idx].mutates_data
            for inpt_idx in aliased_input_indices
        ):
            other_args.extend(
                fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
            )
            other_args_descs.extend(
                fwd_inputs_descs[curr_idx] for curr_idx in aliased_input_indices
            )
            continue

        # Here, we attempt to do a more complicated check to detect false aliasing
        # (e.g. if all the tensors have the same storage, but don't actually overlap)
        # In theory, we could have a large group of tensors that all share storages, where only *some* of them
        # have overlapping memory.
        # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair
        # of tensors in the current group that shares a storage is non-overlapping.
        aliased_input_indices_no_false_sharing = compute_overlapping_inputs(
            aot_config, fwd_inputs, aliased_input_indices
        )
        if len(aliased_input_indices_no_false_sharing) <= 1:
            other_args.extend(
                fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
            )
            other_args_descs.extend(
                fwd_inputs_descs[curr_idx] for curr_idx in aliased_input_indices
            )
            continue

        # We detected an input that was mutated, AND aliases with another input.
        # we need to replace this set of aliased inputs with a single synthetic base.
        # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
        # and error out. We can fix them later.
        # These checks are transitive, so we don't need to check every pair.
        for idx1, idx2 in zip(
            aliased_input_indices, aliased_input_indices[1:], strict=False
        ):
            view1 = fwd_inputs[idx1]
            view2 = fwd_inputs[idx2]
            # The "inputs that are aliased but have different differentiable bases" case
            # is more complicated and hopefully pretty rare. Not currently handled.
            if not is_inference:
                assert _are_differentiable_views(view1, view2), (
                    "aot_autograd() does not yet handle non-differentiable view input mutations."
                )
            # Regenerating views when reinterpreting complex / real tensors seems non-trivial,
            # not handling for now
            assert _same_dtype_views(view1, view2), (
                "aot_autograd() does not yet handle input mutations on views with different dtypes."
            )
        non_none_bases = [
            (i, fwd_inputs[i]._base)
            for i in aliased_input_indices
            if fwd_inputs[i]._base is not None
        ]
        aliases_with_none_bases = [
            fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
        ]
        synthetic_base_desc: AOTInput
        if len(non_none_bases) == 0:
            # Case where none of the aliases have a ._base
            # we generate a synthetic base without gradients, and generate views off of it
            # We hit this case when we have input tensors to the graph that share a storage,
            # but do not have a ._base field.
            # Wondering when we hit this case?
            # The _base field simply says that autograd knows about the aliasing relationship,
            # but sometimes we create tensors which are aliased out of the same storage but guaranteed
            # to be disjoint. In these cases, we will skip setting up the _base relationship
            # for performance reasons (because the fact that the tensors share the same storage
            # is unobservable unless you (1) do naughty things with resize_/as_strided
            # or (2) look at the storage--as we are doing here.)
            # One particular example of this is optimizer steps on the LSTM module:
            # LSTM parameters are packed into a contiguous storage for efficiency reasons when
            # calling cuDNN kernels, so when these parameters get passed to the optimizer we will
            # find they share the same storage, but do not have _base set since they are all disjoint.
            #
            # NOTE: There is one case where this is unsafe:
            # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
            # the same shape as the "actual" base that the tensor came from.
            # For the most part this is fine, because we always use as_strided()
            # to generate the original aliased inputs again.
            # If we were to use view-replay though, this could cause the aliased views
            # to have incorrect sizes.
            example_idx = aliased_input_indices[0]
            example_alias = fwd_inputs[example_idx]
            # Note that this function is reused at both trace time and runtime.
            # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
            synthetic_base = torch.empty(
                (0,), dtype=example_alias.dtype, device=example_alias.device
            )
            # We don't actually have a convenient way of going from storage -> tensor,
            # So using set_() here (we suffer some minor overhead, but this case is rare).
            synthetic_base.set_(example_alias.untyped_storage())
            synthetic_base_desc = SyntheticBaseAOTInput(fwd_inputs_descs[example_idx])
        else:
            # Case where all of the aliases require gradients, and have the same _base.
            i, synthetic_base = non_none_bases[0]
            synthetic_base_desc = ViewBaseAOTInput(fwd_inputs_descs[i])
            for _, other_base in non_none_bases[1:]:
                assert other_base is synthetic_base, (
                    "aot_autograd() does not yet handle non-differentiable view input mutations."
                )
            for alias in aliases_with_none_bases:
                assert alias is synthetic_base, (
                    "aot_autograd() does not yet handle non-differentiable view input mutations."
                )
        base_args.append(synthetic_base)
        base_args_descs.append(synthetic_base_desc)
        for curr_view_idx in aliased_input_indices:
            curr_view = fwd_inputs[curr_view_idx]
            base_idx = len(base_args) - 1
            # We store just enough info here so that we can regenerate the view later.
            # Regeneration: curr_view._view_func(args[base_idx])
            inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
    if len(base_args) == 0:
        assert len(other_args) == len(fwd_inputs)
        # If no synthetic bases are necessary, just return the original inputs.
        return fwd_inputs, fwd_inputs_descs, None
    else:
        from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr

        def make_hashable(arg):
            if isinstance(arg, torch.SymInt):
                # Since only nested SymInt objects can be hashed, we wrap them with
                # SymIntEqByExpr, which is a hashable wrapper of SymInts.
                return SymIntEqByExpr(arg)
            return arg

        # Otherwise, return:
        # (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
        # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
        #     We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
        args_to_functionalization = base_args + other_args
        args_to_functionalization_descs = base_args_descs + other_args_descs

        # Map each argument into its old index.
        # There may be some repeated arguments, so we collect their indices in a list.
        arg_to_old_idx_map = collections.defaultdict(list)
        for i, arg in enumerate(fwd_inputs):
            arg_to_old_idx_map[make_hashable(arg)].append(i)
        # Reverse the list of each argument, so that we can easily pop them one-after-the-other in order.
        for hashable_arg in arg_to_old_idx_map:
            arg_to_old_idx_map[hashable_arg] = list(
                reversed(arg_to_old_idx_map[hashable_arg])
            )

        for i, other_arg in enumerate(other_args):
            new_idx = len(base_args) + i
            old_idx = arg_to_old_idx_map[make_hashable(other_arg)].pop()
            inner_calling_convention_meta[old_idx] = new_idx

        # post process into a list
        post_processed_calling_convention_meta: list[
            Union[int, tuple[int, torch.Tensor]]
        ] = [-1 for _ in range(len(inner_calling_convention_meta))]
        for k, v in inner_calling_convention_meta.items():
            post_processed_calling_convention_meta[k] = v
        # Quick assert: every argument in the inner calling convention should be accounted for.
        for x in post_processed_calling_convention_meta:
            assert x != -1
        return (
            args_to_functionalization,
            args_to_functionalization_descs,
            post_processed_calling_convention_meta,
        )


# Note: [Backward graph lazy lowering]
# After AOTDispatch traces the backward for graphs requiring autograd, we will lower the graph lazily,
# unless we suspect that inductor might specialize and insert additional guards. When we do lazy
# lowering, we stash the AOT backward graph (bw_module) in this class.
#
# Lowering passes are performed on a deepcopy of this bw_module due to compatibility
# with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645.
@dataclass
class AutogradLazyBackwardCompileInfo:
    # The stashed AOT backward graph (see the note above).
    bw_module: Callable
    # NOTE(review): presumably the placeholders of bw_module; confirm against the producer.
    placeholder_list: list[Any]
    # Contexts captured for the later lazy compilation; either may be None.
    saved_context: Optional[TracingContext]
    saved_compile_context: Optional[CompileContext]


# On an AOT Autograd cache hit, we already have a lowered backward, so there is usually
# no need to keep information around for a new lazy compilation. Except for compiled autograd,
# which wants to retrace this backward into a larger graph, and it needs the graph module to do so.
@dataclass
class CachedAutogradLazyBackwardCompileInfo:
    # NOTE(review): presumably a callable that yields the backward graph module; confirm.
    bw_module_fn: Callable


def _raise_if_functorch_active():
    """Fail via ``torch._check`` unless ``peek_interpreter_stack()`` returns ``None``."""
    # not ideal but prevent the user from seeing a nasty traceback - See #138422
    stack = torch._C._functorch.peek_interpreter_stack()
    torch._check(
        stack is None,
        lambda: (
            "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, "
            "which isn't supported. Try wrapping vmap inside torch.compile, or skip compiling the "
            "backward function."
        ),
    )


# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants.
+def _backward_prologue_functional( + ctx_saved_tensors, ctx_symints, metadata, maybe_subclass_metadata, *flat_args +): + # Calling convention: we expect a grad_out passed to the backward: + # - for every output of the fw that does *not* alias an input or graph intermediate + # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations) + # - for every graph intermediate that we need to use to generate an output later. + # The other outputs in the autograd.Function.forward that do *not* show up in the backward include: + # - outputs that alias inputs or graph intermediates + # - updated inputs due to metadata-only mutations. + # We need to return them in the forward, but ensure that they all do not get gradients in the backward, + # and we filter them out here before passing the remaining grad_outputs into the compiled backward. + _raise_if_functorch_active() + + num_intermediate_bases = metadata.num_intermediate_bases + num_mutated_runtime_inps = metadata.num_mutated_inp_runtime_indices + expected_grad_outs = ( + metadata.num_outputs + num_mutated_runtime_inps + num_intermediate_bases + ) + deterministic = metadata.deterministic + global_deterministic = torch.are_deterministic_algorithms_enabled() + if deterministic is not None: + torch._check( + not (not deterministic and global_deterministic), + lambda: ( + "This compiled backward function is being run with " + "torch.use_deterministic_algorithms(True), " + "but it was previously generated during the forward function while " + "torch.use_deterministic_algorithms(False) was set." 
+ ), + ) + + assert len(flat_args) == expected_grad_outs + out_info = metadata.output_info + + inp_tangents, out_tangents, intermediate_base_tangents = ( + flat_args[:num_mutated_runtime_inps], + flat_args[ + num_mutated_runtime_inps : num_mutated_runtime_inps + metadata.num_outputs + ], + flat_args[num_mutated_runtime_inps + metadata.num_outputs :], + ) + # input_info contains info on *every* input, + # But in the backward(), we are only given grad outputs for every mutated input + # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad + input_info = metadata.input_info + inp_tangents_filtered = [ + x + for x, info_idx in zip( + inp_tangents, + metadata.mutated_inp_runtime_indices, + ) + if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad + ] + # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates + out_tangents_filtered = [ + x + for x, info in zip(out_tangents, out_info) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + # intermediate bases always require gradients, and always participate in the backward graph. 
+ flat_bw_args_with_grads = [ + *inp_tangents_filtered, + *out_tangents_filtered, + *intermediate_base_tangents, + ] + num_flat_bw_args_with_grads = len(flat_bw_args_with_grads) + + # sanity asserts + # metadata_only_inps = [ + # x for x, info_idx in zip(inp_tangents, mutated_inp_indices) + # if not input_info[info_idx].mutates_data + # ] + # aliased_outputs = [ + # x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias] + # assert all(x is None for x in metadata_only_inps) + # assert all(x is None for x in aliased_outputs) + # TODO: replace this with FunctionalizedRngRuntimeWrapper + rng_args = [] + if metadata.is_rng_op_functionalized: + # Add the seed and offset to args + rng_args = CUDARngStateHelper.get_torch_state_as_tuple() + + bw_tokens = [None] * metadata.num_backward_tokens + + # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first + # in the bw output order. + + # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls + # There are tests that count these calls, saving to var. + num_ctx_saved_tensors = len(ctx_saved_tensors) + all_args = [ + *ctx_symints, + *ctx_saved_tensors, + *flat_bw_args_with_grads, + *bw_tokens, + *rng_args, + ] + del ctx_saved_tensors + + # Note: [AOTAutograd Backward Guards] + # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph. + # Doing so requires us to "guess" about some of the metadata of our grad_outputs. + # + # In particular: if an output to the forward is a plain tensor or a subclass, + # its corresponding grad_output in the backward **may or may not** be + # a plain tensor or a subclass. 
The main cases are: + # (1) If an output is a plain tensor, its grad_out will also be a plain tensor, + # *unless* the output is used in some subclass compute later in the forward graph, + # which will cause its grad_output to become a subclass + # (2) If an output is a subclass, its grad_out will also be a subclass, + # *unless* the output of the forward did not actually participate in the gradient computation, + # in which case autograd will insert a plain tensor of zeros for the grad_output. + # We could avoid this case with `torch.autograd.Function.set_materialize_grads`, + # although this is not turned on today in AOTAutgrad and would require more work. + # + # Today, we make a guess on subclass-ness based on the above examples, + # and hard-error in the backward if we guessed wrong. + # + # In the future, we should add backward guards that would allow us to + # properly handle this case instead of erroring: we would need to retrace the backward graph, + # since we might produce an entirely different trace if our grad_outputs are subclass or not. + del flat_bw_args_with_grads + + tangents_start_idx = ( + len(all_args) - num_flat_bw_args_with_grads - len(rng_args) - len(bw_tokens) + ) + assert tangents_start_idx == len(ctx_symints) + num_ctx_saved_tensors + tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens) + + # TODO: figure out how to refactor the backward properly + # so I can use aot_dispatch_subclass_wrapper() here. 
+ if maybe_subclass_metadata is not None: + tangents = all_args[tangents_start_idx:tangents_end_idx] + + if len(tangents) != len(metadata.subclass_tangent_meta): + raise RuntimeError( + "The grad inputs should be same number as forward output tangents" + ) + + flat_processed_tangents = list( + itertools.chain.from_iterable( + ( + AOTDispatchAutograd.process_runtime_tangent( + t, + m, + )[1] + ) + for t, m in zip( + tangents, + metadata.subclass_tangent_meta, + ) + ) + ) + + all_args = ( + runtime_unwrap_tensor_subclasses( + all_args[:tangents_start_idx], + # SymInts that are inputs to the backward graph are + # already included in the "all_args" list. + # Any symints coming from tensor subclasses should always + # come from primals, and so they will show up as extra + # arguments to the forward graph, and they will be saved + # as activation in the backward graph. + append_symints=False, + ) + + flat_processed_tangents + + runtime_unwrap_tensor_subclasses( + all_args[tangents_end_idx:], + append_symints=False, + ) + ) + else: + all_args = [ + ( + AOTDispatchAutograd.process_runtime_tangent( + t, + metadata.subclass_tangent_meta[i - tangents_start_idx], + )[0] + if (tangents_start_idx <= i < tangents_end_idx) + else t + ) + for i, t in enumerate(all_args) + ] + + # Backward with forward inputs mutations is not supported in double backward. + if ( + torch.is_grad_enabled() + and metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw + ): + raise RuntimeError( + "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True" + ) + + return all_args + + +def initialize_rng_states( + num_rng: int, + graphsafe_idx: int, + fwd_rng_states: list[torch.Generator], + bwd_rng_states: list[torch.Generator], +): + """ + Initialize the cudagraph safe rng states. 
+ + Initialization of rng states should have a few properties: + - the initialization for each rng state should be independent + - the initialization should be deterministic + - the initialization should be based off current rng state, so that independent graphs do not + have equal rng behavior + + We defer initialization of rng states until runtime because compilation is wrapped + with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations + do not give equal randomness. + """ + with torch.utils._python_dispatch._disable_current_modes(): + seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu") + fwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + bwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + + +# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. +def _backward_epilogue_functional( + metadata, maybe_subclass_metadata, out, *, make_subclass_override=None +): + # Toss out the backward output tokens + num_bw_tokens = metadata.num_backward_tokens + if num_bw_tokens > 0: + out = out[:-num_bw_tokens] + + # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile + out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue( + metadata, out, offset_index=len(out) - 1 + ) + out = tuple(out) + + # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. 
def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryFormatMeta):
    """
    Return ``x`` laid out the way ``memory_format`` describes.

    Two cases, selected by ``memory_format.memory_format``:

    - It is set: ``x`` is returned as-is when already contiguous in that
      ``torch.memory_format``, otherwise ``x.contiguous(...)`` is returned.
    - It is ``None``: ``memory_format.size`` and ``memory_format.stride``
      must both be set (asserted). ``x`` is returned as-is when its shape
      and strides already equal them; otherwise a new tensor is allocated
      with ``torch.empty_strided`` (same dtype, device, layout and
      requires_grad as ``x``) and ``x`` is copied into it.
    """
    target_format = memory_format.memory_format
    if target_format is not None:
        # Coerce to a torch.memory_format; no copy when already satisfied.
        return (
            x
            if x.is_contiguous(memory_format=target_format)
            else x.contiguous(memory_format=target_format)
        )

    wanted_size = memory_format.size
    assert wanted_size is not None
    wanted_stride = memory_format.stride
    assert wanted_stride is not None

    # Expected size and stride are static ints, so plain == against the
    # runtime tensor's shape and strides is fine.
    already_matches = x.shape == wanted_size and x.stride() == wanted_stride
    if already_matches:
        return x

    # empty_strided creates a raw Tensor. Only raw Tensors carry an expected
    # size and stride; subclasses only carry an expected memory_format.
    result = torch.empty_strided(
        size=wanted_size,
        stride=wanted_stride,
        dtype=x.dtype,
        device=x.device,
        layout=x.layout,
        requires_grad=x.requires_grad,
    )
    result.copy_(x)
    return result
+ ) + fail_if_non_empty = False + maybe_prev_message = None + try: + maybe_prev_message = ( + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ) + torch._C._autograd._saved_tensors_hooks_disable( + error_message, fail_if_non_empty + ) + yield + finally: + if maybe_prev_message is None: + torch._C._autograd._saved_tensors_hooks_enable() + else: + torch._C._autograd._saved_tensors_hooks_disable( + maybe_prev_message, fail_if_non_empty + ) + + +@dataclass +class SerializableCompiledFunction: + """ + Represents a result of AOTDispatch after calling the inner compiler + that can be serialized + """ + + compiled_fn: Callable + serialize_fn: Callable + + def __init__(self, compiled_fn: Callable, serialize_fn: Callable): + self.compiled_fn = compiled_fn + self.serialize_fn = serialize_fn + # Equivalent to functools.wraps + functools.update_wrapper( + self, + compiled_fn, + assigned=("__doc__", "__annotations__", "__type_params__"), + ) + + def serialize(self) -> Any: + return self.serialize_fn() + + def __call__(self, *args, **kwargs): + return self.compiled_fn(*args, **kwargs) + + +# This is wrapped in a class just for namespacing purposes +# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly +class AOTDispatchAutograd: + @staticmethod + def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta]): + if not isinstance(x, torch.Tensor): + return x, [x] + + if isinstance(x, FakeTensor): + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) + return x, [x] + + expected_type: Optional[type] = torch.Tensor + expected_meta = None + if isinstance(meta, SubclassCreationMeta): + expected_type = meta.original_subclass_type + expected_meta = meta.meta + + runtime_type = type(x) + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. 
+ if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor + + runtime_meta = None + runtime_subclass_keys: Sequence[str] = [] + + if is_traceable_wrapper_subclass(x): + runtime_subclass_keys, runtime_meta = x.__tensor_flatten__() + + def maybe_coerce(x): + same_type: bool = expected_type == runtime_type + same_meta: bool = expected_meta == runtime_meta + + if same_type and same_meta: + return x + + if not hasattr(x, "__coerce_same_metadata_as_tangent__"): + return None + + if same_type: + # Backward Compatibility, as some Subclass impls can have original 1-arg function. + return x.__coerce_same_metadata_as_tangent__(expected_meta) + + return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type) + + # Coerce to expected type and metadata + orig_x = x + x = maybe_coerce(x) + if x is None: + raise RuntimeError( + f""" +During the backward, we encountered a tensor subclass where we guessed its +metadata incorrectly. + +Expected metadata: {str(expected_meta)}, expected type: {str(expected_type)} + +Runtime metadata: {str(runtime_meta)}, runtime type: {str(runtime_type)} + +shape: {str(orig_x.shape)} +To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__. 
+""" + ) + + # Coerce to expected memory format + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) + + if not is_traceable_wrapper_subclass(x): + return x, [x] + + assert isinstance(meta, SubclassCreationMeta) + if orig_x is not x: + runtime_subclass_keys = x.__tensor_flatten__()[0] + + assert len(meta.attrs) == len(runtime_subclass_keys) + leaves = [] + for attr, attr_meta in meta.attrs.items(): + elem = getattr(x, attr) + new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent( + elem, attr_meta + ) + if new_elem is not elem: + setattr(x, attr, new_elem) + leaves.extend(elem_leaves) + + return x, leaves + + @staticmethod + def post_compile( + compiled_fw_func, # fw_module after compilation + wrappers + compiled_bw_func, # bw_module after compilation + wrappers + maybe_subclass_meta: Optional[SubclassMeta], + num_symints_saved_for_bw_: int, + backward_state_indices: list[int], + disable_amp: bool, + indices_of_inps_to_detach: list[int], + lazy_backward_info: Optional[ + Union[ + AutogradLazyBackwardCompileInfo, + CachedAutogradLazyBackwardCompileInfo, + ] + ], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, # runtime metadata + try_save_cache_entry: Optional[Callable], # Serialization function + ): + # For additional context see Note [CUDA Graph Safe RNG Functionalization] + # Each pair forward, backward rng states must be equal prior to its invocation on any + # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op, + # running forward then backward advances them the same amount and keeps them equal. + # However, a user may invoke multiple forwards, then backwards, such that they are not in sync. + # Initially we have: + # fwd_state0 == bwd_state0. 
+ # Lets say we run: + # fwd0: fwd_state0 -> fwd_state1 + # fwd1: fwd_state1 -> fwd_state2 + # fwd2: fwd_state2 -> fwd_state3 + # If we now invoke bwd2, + # we need to update bwd_state equal to the rng that was observed in fwd2. + # we save the rng_state fwd_state2 in forward because we detect that it is not the + # current backward state and therefore would not be accessible if we do not save it. + # Similarly, if we are going to update the backward state to a new value, and there is a pending + # forwards which needs its current state, we will save it. + # Within the autograd context, we keep track of the curr iteration so that on backward + # we know what the generator state must be before the backward is run. + num_rng = fw_metadata.num_graphsafe_rng_states + graphsafe_idx = fw_metadata.graphsafe_rng_state_index + fwd_rng_states: list[torch.Generator] = [] + bwd_rng_states: list[torch.Generator] = [] + curr_fwd_iter = itertools.count(0) + backward_state_position = 0 + pending_forwards: set[int] = set() + saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {} + + class CompiledFunction(torch.autograd.Function): + compiled_fw = compiled_fw_func + compiled_bw = compiled_bw_func + metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] + maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta + num_symints_saved_for_bw = num_symints_saved_for_bw_ + _aot_id = aot_config.aot_id + _lazy_backward_info = lazy_backward_info + + @staticmethod + def _compiled_autograd_key(ctx): + return (ctx._autograd_function_id, *ctx.symints) + + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, *deduped_flat_tensor_args): + args = deduped_flat_tensor_args + if backward_state_indices: + bw_state = args[backward_state_indices[0]] + assert isinstance(bw_state, BackwardState) + ctx._compiled_autograd_backward_state = bw_state + + if num_rng: + if len(fwd_rng_states) == 0: + assert graphsafe_idx is not None + initialize_rng_states( + 
num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states + ) + + _curr_iter = next(curr_fwd_iter) + ctx._curr_iter = _curr_iter + + # if this state is not contained in the backward, + # we need to save it for when its backward pass happens + if _curr_iter != backward_state_position: + saved_backward_tensor_states[_curr_iter] = [ + rng_state.get_state() for rng_state in fwd_rng_states + ] + + pending_forwards.add(_curr_iter) + args = (*args, *fwd_rng_states) + + # There is a pretty complicated calling convention around what the compiled fw returns. + # The full list of outputs and their relative order is: + # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) + # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version + # of the original view, and not the synthetic base + # - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last + # in the fw output order. + fw_outs = call_func_at_runtime_with_args( + CompiledFunction.compiled_fw, + # pyrefly: ignore [bad-argument-type] + args, + disable_amp=disable_amp, + ) + + num_outputs = CompiledFunction.metadata.num_outputs + num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased + num_mutated_runtime_inps = ( + CompiledFunction.metadata.num_mutated_inp_runtime_indices + ) + num_forward_returns = CompiledFunction.metadata.num_forward_returns + + # Partitioners must put symint arguments at the end separate from tensor arguments + tensors_saved_for_backwards = fw_outs[ + CompiledFunction.metadata.tensors_saved_for_backwards_slice + ] + assert all( + isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards + ) + + def mark_dynamic_activations(activations: list[torch.Tensor]): + for ( + idx, + dims, + ) in CompiledFunction.metadata.dynamic_saved_tensors_idxs.items(): + maybe_mark_dynamic_helper(activations[idx], dims) + return activations + + # See Note [Detaching saved tensors in AOTAutograd] + 
ctx.save_for_backward( + *mark_dynamic_activations( + [ + x.detach() if x._is_view() else x + for x in tensors_saved_for_backwards + ] + ) + ) + symint_outs = fw_outs[ + CompiledFunction.metadata.symints_saved_for_backwards_slice + ] + assert all( + isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) + for x in symint_outs + ), str([type(x) for x in symint_outs]) + ctx.symints = symint_outs + + raw_returns = fw_outs[0:num_forward_returns] + + # Wrap all autograd.Function.forward() outputs that are aliases + # so that autograd.Function doesn't treat them as tensors + if num_mutated_runtime_inps > 0: + for i, idx in enumerate( + CompiledFunction.metadata.mutated_inp_runtime_indices + ): + # We could make this faster by only looping over inputs with metadata-only mutations + # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. + info = CompiledFunction.metadata.input_info[idx] + if info.mutates_metadata and not info.mutates_data: + raw_return_idx = i + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + user_mutated_inputs_raw = raw_returns[ + 0:num_mutated_runtime_inps + ] + mut_inp_infos = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutates_data or x.mutates_metadata + ] + assert len(user_mutated_inputs_raw) == len(mut_inp_infos) + + if CompiledFunction.metadata.num_unsafe_view_outputs > 0: + for idx in CompiledFunction.metadata.unsafe_view_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + o = raw_returns[raw_return_idx] + raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view( + o, o.shape + ) + + if num_outputs_aliased > 0: + for idx in CompiledFunction.metadata.aliased_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + intermediates_raw = raw_returns[ + num_mutated_runtime_inps + num_outputs : + ] + 
assert not any( + isinstance(x, TensorAlias) for x in intermediates_raw + ) + + # invariant: intermediate bases always require gradients, so we don't have to + # consider marking them as non-differentiable. + raw_returns_not_including_intermediate_bases = raw_returns[ + : num_mutated_runtime_inps + num_outputs + ] + raw_returns_meta = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + CompiledFunction.metadata.output_info + + fw_outs_not_requiring_grad = [ + x + for (i, x) in enumerate( + raw_returns_not_including_intermediate_bases + ) + if isinstance(x, torch.Tensor) + and not raw_returns_meta[i].requires_grad + ] + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + ctx._materialize_non_diff_grads = False + return tuple(raw_returns) + + @staticmethod + def backward(ctx, *flat_args): + all_args = _backward_prologue_functional( + ctx.saved_tensors, + ctx.symints, + CompiledFunction.metadata, + CompiledFunction.maybe_subclass_metadata, + *flat_args, + ) + + if num_rng: + nonlocal backward_state_position, bwd_rng_states + curr_backward_iter = ctx._curr_iter + retain_graph = ( + torch._C._autograd._get_current_graph_task_keep_graph() + ) + + # Save current state if we have a pending forward that needs this state + # or this state may be needed again because of retain graph + if ( + backward_state_position in pending_forwards + and backward_state_position not in saved_backward_tensor_states + and ( + backward_state_position != curr_backward_iter + or retain_graph + ) + ): + saved_backward_tensor_states[backward_state_position] = [ + rng_state.get_state() for rng_state in bwd_rng_states + ] + + # Restore saved states if needed + if curr_backward_iter in saved_backward_tensor_states: + if backward_state_position != curr_backward_iter: + for bwd_state, saved_state in zip( + bwd_rng_states, + saved_backward_tensor_states[curr_backward_iter], + ): + bwd_state.set_state(saved_state) + if not 
retain_graph: + del saved_backward_tensor_states[curr_backward_iter] + else: + assert backward_state_position == curr_backward_iter + + backward_state_position = curr_backward_iter + 1 + if not retain_graph: + pending_forwards.remove(curr_backward_iter) + all_args.extend(bwd_rng_states) + + def impl_fn(double_ctx=None): + out = CompiledFunction._backward_impl(ctx, all_args) + return _backward_epilogue_functional( + CompiledFunction.metadata, + CompiledFunction.maybe_subclass_metadata, + out, + ) + + needs_grad = torch.is_grad_enabled() and any( + t.requires_grad for t in all_args if isinstance(t, torch.Tensor) + ) + if needs_grad: + # double backward + return CompiledFunction._double_backward(ctx, impl_fn, all_args) + else: + return impl_fn() + + @staticmethod + def _double_backward(ctx, impl_fn, all_args): + # Ensure that the graph is connected, and error if double backward is performed. + # See comment for why once_differentiable is not sufficient: + # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107 + class CompiledFunctionBackward(torch.autograd.Function): + # CompiledFunctionBackward is not yet supported in dynamo skipfiles + _aot_id = aot_config.aot_id + + @staticmethod + # pyrefly: ignore [bad-override] + def forward(double_ctx, *unused_args): + return impl_fn(double_ctx) + + @staticmethod + def backward(double_ctx, *args): + raise RuntimeError( + "torch.compile with aot_autograd does not currently support double backward" + ) + + CompiledFunctionBackward._compiled_autograd_key = ( # type: ignore[method-assign] + CompiledFunction._compiled_autograd_key + ) + + return CompiledFunctionBackward.apply(*all_args) + + @staticmethod + def _backward_impl(ctx, all_args): + # compiled autograd reimplements this function at proxy_call_aot_backward + assert not backward_state_indices, ( + "BackwardState requires CompiledAutograd" + ) + ctx.maybe_clear_saved_tensors() + + saved_tensors_use_once = ( + not 
torch._C._autograd._get_current_graph_task_keep_graph() + ) + + if CompiledFunction.compiled_bw is None: + assert lazy_backward_info is not None + assert isinstance( + lazy_backward_info, AutogradLazyBackwardCompileInfo + ) + + if ( + hasattr(lazy_backward_info, "saved_context") + and lazy_backward_info.saved_context is not None + ): + assert isinstance( + lazy_backward_info.saved_context, TracingContext + ) + ddp_ctx = lazy_backward_info.saved_context.ddp_optimizer_ctx + if ddp_ctx is not None: + assert ddp_ctx.curr_bucket >= 0, ( + f"expected same # of fw and bw compiles, but found bucket {ddp_ctx.curr_bucket}" + ) + curr_fw_meta = ddp_ctx.metadata_per_bucket[ + ddp_ctx.curr_bucket + ] + # Note [DDPOptimizer and fw_metadata] + # When using the DDPOptimizer, we have a single dynamo graph (and TracingContext), + # but multiple AOTDispatcher graph. + # + # One consequence is that there will be **multiple** fw_metadata objects, one per AOT graph, + # which we stash the fw_metadata on the TracingContext. + # + # Normally what happens is that as we compile AOT graphs 1...N, we clobber the fw_metadata + # for graph i-1 when we start running AOT for graph i. + # Ordinarily this is fine, because inductor no longer needs the metadata from graph i-1. + # + # However, this is a problem for lazy compilation of the backward. During backward compilation, + # we compile the backward lazily at backward runtime, meaning that we will first compile + # backward graph N, N-1, ..., 1. + # We need to ensure that at the time inductor compiles bw graph N-1, it can access + # the corresponding fw_metadta for graph N-1. + # + # We do this by stashing a DDPOptimizerContext, which tracks: + # - the metadata of all N graphs + # - the graph we are currently compiling in our DDPOptimizer region. 
+ ddp_ctx.curr_bucket -= 1 + lazy_backward_info.saved_context.fw_metadata = curr_fw_meta + + if not saved_tensors_use_once: + fw_metadata.bw_donated_idxs = [] + # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd` + if ( + hasattr(lazy_backward_info, "saved_context") + and hasattr(lazy_backward_info.saved_context, "fw_metadata") + and hasattr( + lazy_backward_info.saved_context.fw_metadata, # type: ignore[union-attr] + "bw_donated_idxs", + ) + ): + lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = ( # type: ignore[union-attr] + [] + ) + + bw_module = lazy_backward_info.bw_module + placeholder_list = lazy_backward_info.placeholder_list + saved_context = lazy_backward_info.saved_context + saved_compile_context = lazy_backward_info.saved_compile_context + + context = torch._C._DisableAutocast if disable_amp else nullcontext + metrics_context = get_metrics_context() + with ( + tracing(saved_context), + compile_context(saved_compile_context), + context(), + track_graph_compiling(aot_config, "backward"), + metrics_context, + dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + log_pt2_compile_event=True, + dynamo_compile_column_us="backward_cumulative_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="entire_backward_compile", + ), + callback_handler.install_callbacks( + CallbackTrigger.LAZY_BACKWARD, + str(CompileContext.current_compile_id()), + ), + ): + CompileEventLogger.compilation_metric(is_forward=False) + # See Note: [Backward graph lazy lowering] + CompiledFunction.compiled_bw = aot_config.bw_compiler( + copy.deepcopy(bw_module), placeholder_list + ) + # Maybe save cache entry + if try_save_cache_entry is not None: + try_save_cache_entry( + CompiledFunction.compiled_bw, + bw_module, + fw_metadata, + aot_config, + ) + + if ( + torch._functorch.config.donated_buffer + and not saved_tensors_use_once + and fw_metadata.bw_donated_idxs != [] + ): + torch._check( + False, + 
lambda: ( + "This backward function was compiled with non-empty donated " + "buffers which requires create_graph=False and retain_graph=False. " + "Please keep backward(create_graph=False, retain_graph=False) " + "across all backward() function calls, or set " + "torch._functorch.config.donated_buffer=False to disable " + "donated buffer." + ), + ) + + out = call_func_at_runtime_with_args( + CompiledFunction.compiled_bw, + all_args, + steal_args=True, + disable_amp=disable_amp, + ) + return out + + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=indices_of_inps_to_detach, + trace_joint=True, + disable_amp=disable_amp, + ).post_compile( + CompiledFunction.apply, + aot_config, + runtime_metadata=fw_metadata, + ) + + return compiled_function + + +@dataclass +class DebugAssertWrapper(CompilerWrapper): + flat_requires_grad: list[Optional[bool]] = field(default_factory=list) + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + @wraps(compiled_fn) + def debug_compiled_function(args: list[Any]): + # TODO: Check aliasing relationships + # TODO: Check strides for metadata mutation + # (NB: ideally, this logic is factored out of this function and + # you move these debug checks there) + + # Check requires grad. Bad case is when we compiled with + # requires_grad = False, but input requires_grad = True + # (vice versa is OK; we compute a gradient and then throw + # it away when it hits the input.) 
def pre_compile(
    wrappers: list[CompilerWrapper],
    flat_fn: TraceFn,
    flat_args: list[FxValue],
    flat_args_descs: list[AOTInput],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> tuple[TraceFn, list[FxValue], list[AOTInput], ViewAndMutationMeta]:
    """
    Thread the function, its arguments, their descriptors and the metadata
    through each wrapper's ``pre_compile`` in list order.

    Every wrapper receives what the previous one returned; the final
    4-tuple is handed back to the caller. Mutates wrappers in place.
    """
    fn, args, descs, meta = flat_fn, flat_args, flat_args_descs, fw_metadata
    for stage in wrappers:
        fn, args, descs, meta = stage.pre_compile(
            fn, args, descs, aot_config, fw_metadata=meta
        )
    return fn, args, descs, meta


def post_compile(
    wrappers: list[CompilerWrapper],
    compiled_fn: Callable,
    aot_config: AOTConfig,
    *,
    runtime_metadata: ViewAndMutationMeta,
) -> tuple[Callable, ViewAndMutationMeta]:
    """
    Apply each wrapper's ``post_compile`` to the compiled function, walking
    the wrapper list from last to first. Should be called after
    pre_compile().

    Returns the fully wrapped callable together with ``runtime_metadata``,
    which is passed to every wrapper and returned unchanged.
    """

    def _wrap_once(fn, stage):
        return stage.post_compile(fn, aot_config, runtime_metadata=runtime_metadata)

    wrapped = functools.reduce(_wrap_once, reversed(wrappers), compiled_fn)
    return wrapped, runtime_metadata
+ """ + fw_metadata.make_runtime_safe() + if maybe_subclass_meta is not None: + maybe_subclass_meta.fw_metadata.make_runtime_safe() + if maybe_subclass_meta.grad_input_metas: + for meta in maybe_subclass_meta.grad_input_metas: + if isinstance(meta, SubclassCreationMeta): + meta.make_runtime_safe() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc03c7adb7ee1a3799b874f29f879d23055926d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py @@ -0,0 +1,1297 @@ +# mypy: allow-untyped-defs +""" +The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes +input/output types, metadata, config, function signatures etc. +""" + +from __future__ import annotations + +import collections +import functools +import itertools +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, NewType, Optional, Protocol, TYPE_CHECKING, TypeVar, Union + +import torch +import torch.utils._pytree as pytree +from torch import SymInt, Tensor +from torch._subclasses import FakeTensor +from torch._subclasses.fake_tensor import is_fake +from torch.fx.experimental._backward_state import BackwardState +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. 
import config +from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence +from .utils import strict_zip + + +if TYPE_CHECKING: + import contextlib + from collections.abc import Callable, Iterable, Sequence + + from torch._guards import Source + from torch._inductor.output_code import OutputCode + from torch._inductor.utils import InputType + from torch._ops import OpOverload + + from .descriptors import AOTInput, AOTOutput + from .graph_capture_wrappers import JointFnHandle + + +zip = strict_zip + + +OutputType = Enum( + "OutputType", + ( + # output is not an alias + "non_alias", + # output aliases an input + "alias_of_input", + # output **is** an input tensor + "is_input", + # output has a ._base tensor, which is a graph intermediate. + # We need to return its ._base as a graph output, + # so its requires_grad info is populated correctly. + # Instructs the runtime code to regenerate the current output + # from a base tensor, graph_intermediates[base_idx] + "alias_of_intermediate_save_as_output", + # Same as above; but we don't need to explicitly add its ._base + # as a graph output, because it already **is** a graph output. + "alias_of_intermediate", + # Same as above; but the output's ._base is **already** a user output. + # Instructs the runtime code to regenerate the current output from + # a base tensor, user_outputs[base_idx] + "alias_of_intermediate_base_is_user_output", + # See Note [Intermediate Bases Optimization] + "unsafe_view_alias", + # output is an alias, but has a custom autograd.Function backward. + # In this case, we don't want to do view-replay, since we won't be able to replay the custom function. + # Instead, we'll treat this output "normally", and trace its backward into the graph. + "custom_function_view", + ), +) + + +# This class stores info about every user output. 
@dataclass(frozen=True)
class OutputAliasInfo:
    """
    Immutable per-output record: the output's alias category, raw Python type,
    the index of its base (if it is an alias), its dynamic dims, its
    requires_grad flag, and optional view-replay metadata.
    """

    # Tells us if this output is:
    # (1) a regular (non-aliased) output
    # (2) an alias of a forward input
    # (3) **is** a forward input (special case of "alias_of_input")
    # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
    # (5) an alias of an intermediate, that explicitly requires returning the intermediate
    #     as a graph output
    # (6) an alias of an intermediate, where that intermediate is also a user output
    output_type: OutputType
    # The raw type of the output (torch.Tensor, SymInt, etc)
    raw_type: type
    # If (1) above, then
    # - base_idx is None
    # If (2) or (3) above, then
    # - Tells us that the base of this alias is user_fwd_input[base_idx]
    #   (This is an index into the inputs *before* we make synthetic bases)
    # If (4) or (5) above, then
    # - Tells us that the base of this alias is output_graph_intermediates[base_idx]
    #   here, this refers to the index of the *direct* traced
    # If (6) above, then:
    # - Tells us that the base of this alias is output_user_fwds[base_idx]
    #   here, this refers to the index of the *direct* traced
    base_idx: Optional[int]
    # If it is a Tensor, what the dynamic dims are (otherwise is None)
    dynamic_dims: Optional[set[int]]
    # requires_grad
    requires_grad: bool
    # Sequence of ViewMeta objects.
    #
    # Provides us the means to re-run view functions on other tensors.
    #
    # We need to wrap the actual list of ViewMeta with this class so that
    # we compare the ViewMeta elements appropriately, i.e. their type and
    # the elements returned by the `as_tuple()` call.
    view_meta_sequence: Optional[ViewMetaSequence] = None


class MutationType(Enum):
    """
    Classification of how an input was mutated, as computed by
    `InputAliasInfo.mutation_type`.
    """

    # No data, metadata, or inductor storage-resize mutation was recorded.
    NOT_MUTATED = 1
    # `_check_if_mutation_can_be_in_graph(...)` accepted the mutation.
    MUTATED_IN_GRAPH = 2
    # The input was mutated, but the mutation was not accepted for the graph.
    MUTATED_OUT_GRAPH = 3


# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
    """
    Immutable set of per-input flags describing whether and how one user input
    was mutated; `mutation_type` condenses them into a `MutationType`.
    """

    is_leaf: bool
    mutates_data: bool
    mutates_metadata: bool
    mutations_hidden_from_autograd: bool
    mutations_under_no_grad_or_inference_mode: bool
    mutation_inductor_storage_resize: bool
    mutates_storage_metadata: bool
    requires_grad: bool
    keep_input_mutations: bool

    def __post_init__(self):
        """Check that a storage-metadata mutation is also flagged as a metadata mutation."""
        if self.mutates_storage_metadata:
            # For convenience, we guarantee that this is always true.
            # In practice, If we call .set_(), then at runtime there is no need
            # to additionally fix up the tensor metadata, since our runtime
            # call to inp.set_(updated_inp) will already have the right metadata
            assert self.mutates_metadata

    @functools.cached_property
    def mutation_type(self) -> MutationType:
        """
        Classify this input's mutation.

        Returns NOT_MUTATED when none of `mutates_data`, `mutates_metadata` and
        `mutation_inductor_storage_resize` is set; MUTATED_IN_GRAPH when
        `_check_if_mutation_can_be_in_graph` returns a truthy value for this
        input's flags; MUTATED_OUT_GRAPH otherwise. The result is cached on the
        instance by `functools.cached_property`.
        """
        if (
            (not self.mutates_data)
            and (not self.mutates_metadata)
            and not (self.mutation_inductor_storage_resize)
        ):
            return MutationType.NOT_MUTATED

        if _check_if_mutation_can_be_in_graph(
            self.keep_input_mutations,
            self.mutates_data,
            self.mutates_metadata,
            self.mutations_hidden_from_autograd,
            self.mutations_under_no_grad_or_inference_mode,
            self.mutates_storage_metadata,
            self.mutation_inductor_storage_resize,
            self.requires_grad,
        ):
            return MutationType.MUTATED_IN_GRAPH

        return MutationType.MUTATED_OUT_GRAPH


@dataclass
class MemoryFormatMeta:
    """
    Records either an explicit (size, stride) pair or a `torch.memory_format`
    for a tensor; `from_tensor` decides which of the two is stored.
    """

    # For static shapes we assume tangents have the same strideness as outputs
    size: Optional[Sequence[int]] = None
    stride: Optional[Sequence[int]] = None

    # For dynamic shapes we assume the same memory format: contiguous, channels_last etc.
    memory_format: Optional[torch.memory_format] = None

    @staticmethod
    def from_tensor(t: torch.Tensor) -> Optional[MemoryFormatMeta]:
        """
        Build the MemoryFormatMeta for `t`.

        Only `memory_format` is stored when the
        `guess_tangent_strides_as_outputs` config is off, when `t` is a
        traceable wrapper subclass, or when any entry of `t.shape` /
        `t.stride()` is not a plain `int`. Otherwise `size` and `stride` are
        stored. Every path visible here returns a MemoryFormatMeta instance,
        even though the annotation is Optional.
        """
        # We only memorize expected memory format for
        # 1. Traceable wrapper subclasses
        #    We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors.
        # 2. Dynamic shape tensors
        #    Support for symbolic shapes is not implemented yet.
        use_memory_format: bool = (
            not torch._functorch.config.guess_tangent_strides_as_outputs
            or is_traceable_wrapper_subclass(t)
        )
        if not use_memory_format:
            # Static shape == every size and stride entry is a concrete int.
            is_static_shape = True
            for s in itertools.chain(t.shape, t.stride()):
                if not isinstance(s, int):
                    is_static_shape = False
                    break

            use_memory_format = not is_static_shape

        if use_memory_format:
            return MemoryFormatMeta(
                # pyrefly: ignore [unbound-name]
                memory_format=torch._prims_common.suggest_memory_format(t),
            )

        return MemoryFormatMeta(
            size=t.size(),
            stride=t.stride(),
        )


@dataclass
class PlainTensorMeta:
    # NOTE(review): presumably the index of this (non-subclass) tensor in the
    # flattened plain-tensor argument list -- confirm against the callers.
    unwrapped_idx: int
    memory_format: Optional[MemoryFormatMeta] = None


@dataclass
class SubclassCreationMeta:
    """
    Used for AOTDispatch.
    This dataclass gives us the information we need to reconstruct a tensor subclass
    from our flat inputs.
    Why is this important? The graph that we'd like to trace out contains flat tensor inputs,
    But the user's original model may have subclass inputs and outputs.
    So we need to wrap/unwrap subclasses as necessary to translate between the user's
    view (subclass inps/outs), and the backend compiler's view (graph with no subclass args).

    Complications arise mostly from the fact that a subclass can hold more than one inner tensor;
    So for a given subclass input/output, we need to carefully track which indices map
    to the subclass tensor in the corresponding "dense-tensor-only" graph.
    """

    # In the inner graph that only takes in dense tensor inputs,
    # this maps to the first index of "tensors that should go in this subclass wrapper"
    flat_tensor_start_idx: int
    # arg_count is inclusive of the arg_counts of any
    # inner tensor subclasses: If I have a TwoTensor and
    # both of its inner elements are TwoTensors, then the
    # arg_count of the outer-most subclass will be 4
    arg_count: int
    # Mark whether or not symints were included. This flag is only used in one assertion
    # in "wrap_tensor_subclasses"
    included_subclass_symints: bool
    # meta and attrs are produced by the subclass's __tensor_flatten__.
    # We need to keep them around along with outer_size / outer_stride to plumb them
    # into __tensor_unflatten__
    attrs: dict[str, Union[SubclassCreationMeta, PlainTensorMeta]]
    outer_size: Iterable[Union[None, int, torch.SymInt]]
    outer_stride: Iterable[Union[None, int, torch.SymInt]]
    meta: Any
    # Stores the original subclass itself.
    # This is needed because we need the autograd metadata on the original subclass
    # (this is guaranteed to be a wrapper subclass that holds a fake tensor,
    #  so holding onto this at runtime shouldn't leak memory)
    # This field is nulled out after calling make_runtime_safe()
    original_subclass: Optional[torch.Tensor]

    # Used at runtime to determine the subclass type, so we don't need to save the original subclass
    original_subclass_type: Optional[type] = None
    memory_format: Optional[MemoryFormatMeta] = None

    def compute_outer_size_and_stride(
        self,
        all_args,
        *,
        curr_start_idx: int,
    ):
        """
        Return `(outer_size, outer_stride)` with placeholder entries filled in
        from `all_args`.

        `compute_symint_placeholders` yields one flag per entry of the stored
        outer size / stride; entries with a truthy flag are replaced by values
        read from `all_args`, the rest are kept unchanged.

        NOTE(review): the inner `compute` slices from the enclosing
        `curr_start_idx` rather than its own `start_idx` parameter, and
        advances by `len(placeholders)` rather than `sum(placeholders)`. For
        the stride pass this re-reads the slice from the position used for the
        size pass. Confirm this is intended against the layout the unwrapping
        code produces for `all_args`.
        """
        from .subclass_utils import compute_symint_placeholders

        def compute(outer, start_idx):
            placeholders = compute_symint_placeholders(outer)
            has_symbolic = any(placeholders)

            if has_symbolic:
                start = curr_start_idx
                end = start_idx + sum(placeholders)
                it_args = iter(all_args[start:end])
                it_placeholders = iter(placeholders)
                # The first lambda is the predicate (consumes one flag per leaf);
                # the second supplies the replacement value for flagged leaves.
                return pytree.tree_map_only(
                    lambda _: next(it_placeholders), lambda _: next(it_args), outer
                ), start + len(placeholders)
            else:
                return outer, start_idx

        outer_size, next_idx = compute(self.outer_size, curr_start_idx)
        outer_stride, _ = compute(self.outer_stride, next_idx)
        return outer_size, outer_stride

    def creation_fn(
        self,
        all_args,
        *,
        is_runtime: bool,
    ):
        """
        Rebuild the subclass tensor described by this metadata from `all_args`.

        Inner tensors are taken from `all_args` starting at
        `flat_tensor_start_idx` (one slot per PlainTensorMeta attr, `arg_count`
        slots per nested SubclassCreationMeta, which is rebuilt recursively),
        then passed to the subclass type's `__tensor_unflatten__`.

        When `is_runtime` is True the stored `original_subclass_type` is used
        and outer size/stride are recomputed from `all_args`. Otherwise the
        type of `original_subclass` and the stored `outer_size` /
        `outer_stride` are used, and autograd metadata is mirrored from
        `original_subclass` onto the rebuilt tensor.
        """
        inner_tensors = {}

        curr_start_idx = self.flat_tensor_start_idx
        for attr, creation_meta in self.attrs.items():
            if isinstance(creation_meta, PlainTensorMeta):
                subclass = all_args[curr_start_idx]
                curr_start_idx += 1
            else:
                subclass = creation_meta.creation_fn(
                    all_args,
                    is_runtime=is_runtime,
                )
                curr_start_idx += creation_meta.arg_count
            inner_tensors[attr] = subclass

        if is_runtime:
            assert self.original_subclass_type is not None
            original_subclass_type = self.original_subclass_type
        else:
            original_subclass_type = type(self.original_subclass)

        if is_runtime:
            outer_size, outer_stride = self.compute_outer_size_and_stride(
                all_args,
                curr_start_idx=curr_start_idx,
            )
        else:
            outer_size, outer_stride = self.outer_size, self.outer_stride

        rebuilt = original_subclass_type.__tensor_unflatten__(  # type: ignore[attr-defined]
            inner_tensors, self.meta, outer_size, outer_stride
        )

        if not is_runtime:
            # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
            # has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass.
            # We don't trace through the autograd engine at runtime though, so no need
            # to compute this extra metadata then!
            torch._mirror_autograd_meta_to(self.original_subclass, rebuilt)  # type: ignore[attr-defined]

        return rebuilt

    def make_runtime_safe(self):
        """
        Strip state that should not be kept at runtime.

        Records `type(original_subclass)` in `original_subclass_type`, clears
        `original_subclass`, rewrites `outer_size` / `outer_stride` as tuples
        in which nested-int SymInts become -1 and any other SymInt becomes
        None, and recurses into nested SubclassCreationMeta entries of `attrs`.
        Asserts that `original_subclass` has not already been cleared.
        """

        def _make_size_runtime_safe(x: Union[None, int, torch.SymInt]) -> Optional[int]:
            dummy = -1
            if isinstance(x, torch.SymInt):
                # Replace nested ints by a dummy value (-1) as NJT ignores
                # the outer_size/outer_stride at runtime.
                return dummy if x.node.is_nested_int() else None
            return x

        assert self.original_subclass is not None
        self.original_subclass_type = type(self.original_subclass)
        self.original_subclass = None

        # Note: NJT outer_size in AOTDispatcher
        # `_make_size_runtime_safe` replaces any nested int with a dummy value (-1)
        # to prevent serializing a SymInt at runtime. Internally, nested tensor __tensor_unflatten__
        # is designed to safely ignore this dummy value.
        # For more details, see: https://github.com/pytorch/pytorch/blob/5141ade8e30c64e873e14dcc8de233da45d15025/torch/nested/_internal/nested_tensor.py#L266-L299  # noqa: B950
        self.outer_size = tuple(map(_make_size_runtime_safe, self.outer_size))
        self.outer_stride = tuple(map(_make_size_runtime_safe, self.outer_stride))

        # Recurse on nested subclass info
        for creation_meta in self.attrs.values():
            if isinstance(creation_meta, SubclassCreationMeta):
                creation_meta.make_runtime_safe()

    def __post_init__(self):
        # sanity assert to make sure we don't leak memory
        assert is_fake(self.original_subclass)


# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
# NOTE: This class is saved in AOTAutogradCache, If you are adding elements, make sure
# they are covered by warm cache tests.
@dataclass(eq=False)
class ViewAndMutationMeta:
    # length = # user inputs
    # This gives us info about every input, and what sort of mutation happened to it (if any)
    input_info: list[InputAliasInfo]

    # length = # user outputs
    # This gives us info about every output (mostly around whether it aliases other tensors)
    output_info: list[OutputAliasInfo]

    # length = the number of intermediate bases appended as outputs to the end of the forward graph.
    # Note: this is not necessarily the same thing as:
    #   len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
    # Because outputs might share a ._base, or an output's ._base might itself be
    # another user output (in both cases, we won't redundantly append bases to the end of the graph)
    num_intermediate_bases: int

    # For inference only: instructs us to keep data-only input mutations directly in the graph
    keep_input_mutations: bool

    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
    #        + (# intermediate bases)
    # These are the FakeTensor (or potential SymInt) outputs that we traced from our
    # metadata pass of the user's forward function.
    # Their only use today is to pass them as a best-guess for tangents when tracing the joint.
    # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
    # pass once, and reuse the output throughout AOTAutograd
    traced_tangents: list[Any]

    # TODO doc
    traced_tangents_descs: list[AOTInput]

    # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs
    # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
    # Given a (potentially larger) list of plain torch tensors.

    # Taking subclass_inp_meta as an example:
    #   subclass_inp_meta[i] = j (an int) tells us:
    #     "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph."
    #   subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2)
    #     "The i'th user input is subclass holding two inner tensors, which are
    #      inputs[3] and inputs[4] of the plain-tensor graph".

    # length = # user inputs
    subclass_inp_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
    # So, the full set of outputs to the forward graph looks something like:
    # (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
    # where the first 3 of those 4 can be subclasses
    # (but not saved_for_bw tensors, since these are internal to the compiler
    # and not user visible, so there's no point in wrapping/unwrapping them at runtime).
    # This list contains subclass information on all of the fw graph outputs
    # except for saved_for_bw_tensors.
    subclass_fw_graph_out_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
    # length = # backward graph inputs
    subclass_tangent_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
    # TODO: we should kill this
    # (need to default it to not break internal)
    is_train: bool = False

    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
    #        + (# intermediate bases)
    # At runtime, we don't keep the traced_tangents around since they're not serializable.
    # Instead, we keep any necessary subclass metadata necessary about each traced_tangent.
    # This list is generated after calling make_runtime_safe().
    traced_tangent_metas: Optional[list[Any]] = None

    num_symints_saved_for_bw: Optional[int] = None

    # The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue
    # NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode
    # that is intended to be in effect prior to running the graph, in keeping with
    # equivalence to eager mode. It is the responsibility of upstream graph acquisition
    # to reset the grad mode to its pre-graph value prior to calling aot_autograd.
    grad_enabled_mutation: Optional[bool] = None

    # Keeps track of whether `torch.use_deterministic_algorithms` was turned on
    # when the forward was run. If deterministic mode was turned off during the
    # forward, but is turned on during the backward call, then an error is
    # raised
    deterministic: Optional[bool] = None

    # Keeps track of which input indices store parameters (which we will treat as static)
    static_input_indices: list[int] = field(default_factory=list)

    # Map of effect type (ex. _EffectType.ORDERED) to token.  If there are
    # side-effectful operators, FunctionalTensorMode will populate this
    # dictionary telling us how many tokens we will need during tracing.
    tokens: dict[Any, torch.Tensor] = field(default_factory=dict)

    # Only filled in if/when we trace the joint function
    # If an input requires grad and is mutated in the backward, it is only safe to keep the mutation
    # in the graph if gradients are disabled while the backward runs
    # (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True)
    # At runtime during the backward, we use this list of indices to error properly if we find out
    # that it was not safe to include a backward mutation in the graph.
    indices_of_inputs_that_requires_grad_with_mutations_in_bw: list[int] = field(
        default_factory=list
    )

    # Indexes of saved tensors which are donated buffer.
    # Donated buffer means the tensor is not alias of any forward user input, forward user output,
    # and backward output.
    bw_donated_idxs: Optional[list[int]] = None

    # Number of tokens used in backward, appended at the end of backward outputs.
    # Filled after tracing joint function.
    num_backward_tokens: int = 0

    # Number of rng states that will get thread into the forward and backward for
    # cudagraph compatible run_and_save_rng
    num_graphsafe_rng_states: int = 0

    graphsafe_rng_state_index: Optional[int] = None

    def __post_init__(self):
        """
        Precompute index lists and counts derived from `input_info`,
        `output_info`, `traced_tangents` and the `functionalize_rng_ops`
        config, and store them as instance attributes.
        """
        # pre-compute the indices of the inputs that are mutated.
        # When keep_input_mutations is set, we don't need to worry about our epilogue
        # handling data-only mutations, because we keep them directly in the graph.
        mutated_inp_runtime_indices = [
            i
            for i, m in enumerate(self.input_info)
            if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH)
        ]

        mutated_graph_handled_indices = [
            i
            for i, m in enumerate(self.input_info)
            if m.mutation_type == MutationType.MUTATED_IN_GRAPH
        ]
        self.mutated_graph_handled_indices = mutated_graph_handled_indices
        self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices)

        mutated_graph_handled_indices_seen_by_autograd = [
            i
            for i in mutated_graph_handled_indices
            if not self.input_info[i].mutations_hidden_from_autograd
        ]

        self.mutated_graph_handled_indices_seen_by_autograd = (
            mutated_graph_handled_indices_seen_by_autograd
        )
        self.num_mutated_graph_handled_indices_seen_by_autograd = len(
            self.mutated_graph_handled_indices_seen_by_autograd
        )

        aliased_out_indices = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type
            not in [
                OutputType.non_alias,
                OutputType.unsafe_view_alias,
                OutputType.custom_function_view,
            ]
        ]
        unsafe_view_out_indices = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type is OutputType.unsafe_view_alias
        ]

        # This is pre-computed in post_init for perf.
        # It contains the index of every element
        # of input_info that corresponds to a mutation (data or metadata or both)
        self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
        self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices)

        # This is pre-computed for perf.
        # It contains the index of every element
        # of output_info that corresponds to an alias (either of an input or intermediate)
        self.aliased_out_indices = aliased_out_indices
        self.unsafe_view_out_indices = unsafe_view_out_indices
        self.num_outputs = len(self.output_info)
        self.num_outputs_non_aliased = len(
            [
                x
                for x in self.output_info
                if x.output_type
                in [
                    OutputType.non_alias,
                    OutputType.unsafe_view_alias,
                    OutputType.custom_function_view,
                ]
            ]
        )
        self.num_outputs_aliased_to_inputs = len(
            [
                x
                for x in self.output_info
                if x.output_type
                in [
                    OutputType.alias_of_input,
                    OutputType.is_input,
                ]
            ]
        )
        self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices)
        self.num_outputs_aliased_to_intermediates = len(
            [
                x
                for x in self.output_info
                if x.output_type
                in [
                    OutputType.alias_of_intermediate,
                    OutputType.alias_of_intermediate_save_as_output,
                    OutputType.alias_of_intermediate_base_is_user_output,
                ]
            ]
        )
        self.num_outputs_aliased = (
            self.num_outputs_aliased_to_inputs
            + self.num_outputs_aliased_to_intermediates
        )

        # Record dynamic outputs of the Dynamo traced forward graph
        # Mark them as dynamic at the end of the runtime wrapper
        self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info)

        # Record the indices of dynamic outputs in the partitioned forward graph
        # Mark them as dynamic in the runtime wrapper
        # activation index -> dynamic dims indices
        self.dynamic_saved_tensors_idxs: dict[int, set[int]] = {}

        # See Note: [AOTAutograd Backward Guards]
        # This is pre-computed for fast asserts on the types of our grad_outputs in the backward.
        # Eventually, we should kill this and replace with real backward guards.
        # (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor)
        self.output_types = [
            torch.Tensor if isinstance(x, FakeTensor) else type(x)
            for x in self.traced_tangents
        ]

        self.is_rng_op_functionalized = config.functionalize_rng_ops
        # All of the above metadata is collected by tracing the fw function.
        # However, extra outputs for rng offsets behave differently. Both fwd
        # and bwd graphs have their own outputs for the total consumed offsets.
        # Unlike mutated inputs, we don't have to worry about sending the right
        # set of tensors between fwd and bwd. Fwd and bwd offsets are
        # independent and simpler to handle. Therefore, we track them
        # separately.
        self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0

        # Our forward() returns both (tokens, mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints)
        # Tokens will be split out before mutations/view handling and we do not count them here.
        self.num_forward_returns = (
            self.num_mutated_inp_runtime_indices
            + self.num_outputs
            + self.num_intermediate_bases
        )
        # In case of functionalization of rng ops, the fw_module returns one
        # additional output for rng offset. This rng offset is used right
        # away to advance the rng state, and is not passed on to the raw
        # outputs. However, we need to know the exact boundary to identify
        # which tensors to be saved for the bwd graph.  num_forward captures
        # this information.
        self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset

    def make_runtime_safe(self):
        """
        There are various fields in ViewAndMutationMeta that aren't serializable. This function is called after all tracing
        is completed to simplify certain fields in the metadata so that they can be safely cached.

        Doing so may lose information (in the case of traced_tangents), but none of the information is needed at runtime.
        """
        # TODO: This function is only a best effort: there are other fields that may not be cache safe
        # (i.e., there's no guarantee that tensor_flatten() returns a serializable result), or that
        # SubclassCreationMeta is cache safe.
        assert self.traced_tangent_metas is None

        def extract_metadata(t):
            if isinstance(t, torch.Tensor) and is_traceable_wrapper_subclass(t):
                (inner_tensors, flatten_spec) = t.__tensor_flatten__()  # type: ignore[attr-defined]
                # Technically, we only need the flatten_spec, not the inner tensors.
                # However, some Tensor subclasses (like TwoTensor) may have flatten_spec = None.
                # And we want to be able to assert that this metadata is non-None,
                # to distinguish between "this was a tensor subclass with no metadata" vs.
                # "this wasn't a tensor subclass at all".
                return (inner_tensors, flatten_spec)
            else:
                return None

        self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
        # Clear traced tangents at runtime
        self.traced_tangents = []
        for inp_meta in self.subclass_inp_meta:
            if isinstance(inp_meta, SubclassCreationMeta):
                inp_meta.make_runtime_safe()
        for inp_meta in self.subclass_fw_graph_out_meta:
            if isinstance(inp_meta, SubclassCreationMeta):
                inp_meta.make_runtime_safe()
        for inp_meta in self.subclass_tangent_meta:
            if isinstance(inp_meta, SubclassCreationMeta):
                inp_meta.make_runtime_safe()

    @property
    def tensors_saved_for_backwards_slice(self):
        """
        Slice that starts at `num_forward` and stops before the trailing
        `num_symints_saved_for_bw` entries (or runs to the end when that count
        is 0). Requires `num_symints_saved_for_bw` to be set.
        """
        assert self.num_symints_saved_for_bw is not None
        if self.num_symints_saved_for_bw > 0:
            return slice(self.num_forward, -self.num_symints_saved_for_bw)
        else:
            return slice(self.num_forward, None)

    @property
    def symints_saved_for_backwards_slice(self):
        """
        Slice selecting the trailing `num_symints_saved_for_bw` entries, or an
        empty slice when that count is 0. Requires `num_symints_saved_for_bw`
        to be set.
        """
        assert self.num_symints_saved_for_bw is not None
        if self.num_symints_saved_for_bw > 0:
            return slice(-self.num_symints_saved_for_bw, None)
        else:
            return slice(0, 0)  # empty slice

    def __eq__(self, other):
        """
        Compare a selected subset of fields only. `traced_tangents` are
        compared by length and by per-element `shape` / `dtype`, not by value.
        Returns NotImplemented for a non-ViewAndMutationMeta operand.
        """
        if not isinstance(other, ViewAndMutationMeta):
            return NotImplemented
        return (
            self.input_info == other.input_info
            and self.output_info == other.output_info
            and self.num_intermediate_bases == other.num_intermediate_bases
            and self.keep_input_mutations == other.keep_input_mutations
            and self.is_rng_op_functionalized == other.is_rng_op_functionalized
            and self.num_outputs_rng_offset == other.num_outputs_rng_offset
            and len(self.traced_tangents) == len(other.traced_tangents)
            and all(
                x.shape == y.shape and x.dtype == y.dtype
                # `zip` is rebound to `strict_zip` at module level.
                for x, y in zip(self.traced_tangents, other.traced_tangents)
            )
            and self.num_backward_tokens == other.num_backward_tokens
        )


@dataclass(eq=False)
class SubclassMeta:
    # A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses)
    # So for example, if the user had a model containing two `TwoTensor` inputs,
    # Then `SubclassMeta.fw_metadata.input_infos` would have length 4 here.
    fw_metadata: ViewAndMutationMeta

    # Note: [Computing Subclass Metadata about grad_inputs]
    # Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses
    #
    # You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs?
    # (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous)
    #
    # This doesn't really work though. take this example:
    #
    # def f(DoubleTensor, DenseTensor):
    #     return DoubleTensor  * DenseTensor
    #
    # In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor.
    # When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs.
    # This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input)
    # and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors.
    #
    # Note that this info **cannot** easily be figured out from ViewAndMutationMeta.
    # We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed.
    #
    # See Note: [AOTAutograd Backward Guards]
    # This will also eventually require us to install backward guards,
    # in case we made incorrect assumptions about the subclass-ness of our grad_outputs
    #
    # Optional field because we don't compute for inference graphs
    grad_input_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = (
        None
    )

    def __init__(self) -> None:
        # The fields in this class get set after its construction.
        pass


# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
#   We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
    alias: torch.Tensor


@dataclass
class BackwardSignature:
    """
    Provides information about the backward section of an exported
    joint forward-backward graph.
    For a particular fx GraphModule, this class contains information on:
    (1) A mapping from each gradient (backwards output) to the parameter
        it corresponds to (forward input)
    (2) A mapping from each gradient (backwards output) to the user input
        it corresponds to (forward input)
    (3) Which of the forward outputs corresponds to the loss, that we backprop on.

    Each string name is the `node.name` of the corresponding node in the fx graph.
    """

    gradients_to_parameters: dict[str, str]
    gradients_to_user_inputs: dict[str, str]
    loss_output: str


GraphOutputName = NewType("GraphOutputName", str)
GraphInputName = NewType("GraphInputName", str)
FQN = NewType("FQN", str)


@dataclass
class GraphSignature:
    """
    Provides information about an exported module.
    For a particular fx GraphModule, this class contains information on:
    (1) Which graph inputs are parameters, buffers, or user inputs
    (2) (for params/buffers) a mapping from the name of each graph argument
        to its parameter/buffer FQN in the original nn.Module.
    (3) If there are input mutations, these are represented as extra outputs
        in the fx GraphModule. We provide a mapping from these
        extra output names to the names of the actual inputs.
    (4) The pytree metadata on how to flatten/unflatten inputs and outputs.
        The corresponding FX GraphModule only accepts and returns
        pytree-flattened inputs/outputs.
    (5) (Optionally) if the FX is a joint forward-backward graph, we provide
        a signature on the backward section of the joint graph.
    """

    parameters: list[FQN]
    buffers: list[FQN]

    user_inputs: list[GraphInputName]
    user_outputs: list[GraphOutputName]
    inputs_to_parameters: dict[GraphInputName, FQN]
    inputs_to_buffers: dict[GraphInputName, FQN]

    # If the user's module mutates a buffer,
    # it's represented in the graph as an extra graph output.
    # This dict is a mapping from
    # "graph outputs that correspond to updated buffers"
    # to the FQN names of those mutated buffers.
    buffers_to_mutate: dict[GraphOutputName, FQN]
    parameters_to_mutate: dict[GraphOutputName, FQN]
    user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]

    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec

    backward_signature: Optional[BackwardSignature]

    input_tokens: list[GraphInputName]
    output_tokens: list[GraphOutputName]

    @classmethod
    def from_tracing_metadata(
        cls,
        *,
        in_spec: pytree.TreeSpec,
        out_spec: pytree.TreeSpec,
        graph_input_names: list[str],
        graph_output_names: list[str],
        view_mutation_metadata: ViewAndMutationMeta,
        named_parameters: list[str],
        named_buffers: list[str],
        num_user_inputs: int,
        num_user_outputs: int,
        trace_joint: bool,
        loss_index: Optional[int],
        backward_signature: Optional[BackwardSignature],
    ) -> GraphSignature:
        """
        Build a GraphSignature by slicing `graph_input_names` and
        `graph_output_names` according to the calling convention described in
        the body (tokens, then params, buffers, user inputs on the input side;
        tokens, mutated inputs, user outputs, gradients on the output side).

        Asserts that every graph input and output is accounted for, and that
        the number of data-mutated inputs equals
        `view_mutation_metadata.num_mutated_inp_runtime_indices`. When
        `trace_joint` is True it also asserts that no parameter is
        data-mutated. `loss_index` is accepted but not read in this method.
        """
        graph_inputs = graph_input_names
        graph_outputs = graph_output_names
        parameters = list(named_parameters)
        buffers = list(named_buffers)
        num_tokens = len(view_mutation_metadata.tokens)

        # Calling convention assumptions:
        # (1) graph inputs = (input_tokens, params, buffers, user_inputs)
        # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
        # (If we are capturing an inference graph, this convention is identical
        #  except that param_gradients is empty)
        # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens

        # Address input calling conventions:
        start, stop = 0, num_tokens
        input_tokens = graph_inputs[start:stop]

        start, stop = stop, stop + len(parameters)
        # `zip` is rebound to `strict_zip` at module level.
        inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))

        start, stop = stop, stop + len(buffers)
        inputs_to_buffers = dict(
            zip(
                graph_inputs[start:stop],
                buffers,
            )
        )

        start, stop = stop, stop + num_user_inputs
        user_inputs = graph_inputs[start:stop]

        # We should've gone through all the inputs now
        assert len(graph_inputs) - stop == 0

        # Address output calling conventions:
        start, stop = 0, num_tokens
        output_tokens = graph_outputs[start:stop]

        names = [*input_tokens, *parameters, *buffers, *user_inputs]
        mutations = []
        for idx, input_info in enumerate(view_mutation_metadata.input_info):
            if input_info.mutates_data:
                if trace_joint:
                    # Only buffers can be mutated, not parameters
                    assert idx >= len(parameters)
                # `names` starts with the tokens, while `input_info` does not,
                # hence the `num_tokens` offset.
                mutations.append(names[idx + num_tokens])

        assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices

        start, stop = (
            stop,
            stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
        )
        outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))

        user_inputs_to_mutate = {}
        buffers_to_mutate = {}
        parameters_to_mutate = {}
        for output_name, mutation_name in outputs_to_mutations.items():
            if mutation_name in user_inputs:
                # pyrefly: ignore [unsupported-operation]
                user_inputs_to_mutate[output_name] = mutation_name
            else:
                assert mutation_name in buffers or mutation_name in parameters
                if mutation_name in buffers:
                    # pyrefly: ignore [unsupported-operation]
                    buffers_to_mutate[output_name] = mutation_name
                else:
                    # pyrefly: ignore [unsupported-operation]
                    parameters_to_mutate[output_name] = mutation_name

        start, stop = stop, stop + num_user_outputs
        user_outputs = graph_outputs[start:stop]

        # Whatever is left must be exactly the gradients listed in the backward signature.
        unused_outputs = len(graph_outputs) - stop
        if backward_signature is not None:
            unused_outputs -= len(backward_signature.gradients_to_parameters) + len(
                backward_signature.gradients_to_user_inputs
            )
        assert unused_outputs == 0

        return GraphSignature(
            parameters=parameters,  # type: ignore[arg-type]
            buffers=buffers,  # type: ignore[arg-type]
            user_inputs=user_inputs,  # type: ignore[arg-type]
            user_outputs=user_outputs,  # type: ignore[arg-type]
            inputs_to_buffers=inputs_to_buffers,  # type: ignore[arg-type]
            inputs_to_parameters=inputs_to_parameters,  # type: ignore[arg-type]
            user_inputs_to_mutate=user_inputs_to_mutate,
            buffers_to_mutate=buffers_to_mutate,  # type: ignore[arg-type]
            parameters_to_mutate=parameters_to_mutate,  # type: ignore[arg-type]
            in_spec=in_spec,
            out_spec=out_spec,
            backward_signature=backward_signature,
            input_tokens=input_tokens,  # type: ignore[arg-type]
            output_tokens=output_tokens,  # type: ignore[arg-type]
        )


@dataclass
class AOTAutogradCacheInfo:
    # NOTE(review): presumably the key used to look this entry up in
    # AOTAutogradCache (see `AOTConfig.cache_info`) -- confirm against the cache code.
    cache_key: str
    start_time_ns: int
    forward_symints: list[torch.SymInt]


@dataclass
class AOTConfig:
    """
    Configuration for AOTDispatcher
    """

    fw_compiler: Callable
    bw_compiler: Callable
    partition_fn: Callable
    decompositions: dict[OpOverload, Callable]
    num_params_buffers: int
    aot_id: int
    keep_inference_input_mutations: bool
    is_export: bool = False
    no_tangents: bool = False
    dynamic_shapes: bool = False
    aot_autograd_arg_pos_to_source: Optional[list[Source]] = None
    static_input_indices: Optional[list[int]] = None
    inference_compiler: Optional[Callable] = None
    enable_log: bool = True
    # this is always false outside of export.
    pre_dispatch: bool = False
    # Key to use for AOTAutogradCache
    cache_info: Optional[AOTAutogradCacheInfo] = None
    # If we should ignore the shape_env in the ambient tracing_context.
    # The net effect is that if dynamic shapes are on, we end up
    # specializing on example_inputs.
    # Used only by standalone_compile.
    ignore_shape_env: bool = False
    precompile_backend_id: Optional[str] = None
    force_non_lazy_backward_lowering: bool = False
    # This config makes sure to check certain things like
    # mutating input with req_grad in export joint tracing.
    export_trace_joint: bool = False
    disable_functionalization: bool = False

    def __post_init__(self):
        """Check that `pre_dispatch` is only enabled together with `is_export`."""
        if self.pre_dispatch:
            assert self.is_export, "Can only have pre_dispatch IR for export."


# TODO: types here
# plain_tensor_trace_fn, when it is joint, has tuple structure on the trace
# info too!
# TODO: this needs to be generic, parameterized on AOTDescriptor
SubclassTracingInfo = collections.namedtuple(
    "SubclassTracingInfo",
    [
        "plain_tensor_trace_fn",
        "plain_tensor_args",
        "plain_tensor_args_descs",
        "maybe_subclass_meta",
    ],
)


@dataclass
class AOTState:
    """
    When we run AOTAutograd, this class encapsulates the state in the compiler which
    must be preserved across stages.  This is state in the traditional sense (not an
    environment) because some values in this structure change as we progress through
    pipelines in AOTAutograd.
    """

    # Whether or not we need to handle autograd when doing graph capture and
    # compilation.  Although the calling convention for non-autograd graph
    # capture in AOTAutograd is simple and can be relied upon, the autograd
    # graph capture calling convention is quite complicated and in general you
    # are only expected to pass to aot_stage2_compile to process.
    needs_autograd: bool

    # The FAKE flat arguments which we will do tracing with.  Although you
    # might naively expect this to be immutable, it's not: when we perform
    # tracing, we may execute code that modifies the metadata of inputs,
    # causing the args to become "invalid".  It's also nontrivial to have a
    # "golden" set of fake values and deepcopy them just in time when you
    # might destructively mutate them (Voz and I tried very hard to do this).
    # So we just periodically renew this field.  Don't worry too much about
    # this unless you're specifically trying to track down an input metadata
    # mutation bug.
    #
    # (By the way, this is NEVER the joint inputs!  Those only ever go in
    # AOTGraphCapture)
    flat_args: list[FxValue]

    # The descriptor for each argument in flat_args.
    flat_args_descs: list[AOTInput]

    # This contains view and mutation information about the function, which we
    # detected by doing an initial trace when we created this state.
    fw_metadata: ViewAndMutationMeta

    # Top-level configuration
    # This is morally immutable but sometimes we are naughty and mutate it.
    aot_config: AOTConfig

    # When performing AOTAutograd traces and other passes, we typically
    # require a lot of active context managers; most typically these either
    # (1) ensure we are faithfully replicating the original PyTorch context
    # managers or (2) toggle some behaviors in PyTorch to make it more
    # suitable for tracing.  When you use AOTState, you're expected to have
    # created an ExitStack, entered it; then while we are running AOTAutograd
    # we will add things onto the stack as necessary.  When you're all done
    # with processing AOTAutograd, you can exit this stack.  All functions
    # that take AOTState expect the ExitStack to not have been exited yet.
    #
    # TODO: We potentially could offer a resumable context manager, where you
    # can cancel it and reenable it later when you need it.
    stack: contextlib.ExitStack


# The value types that may appear as a flat argument or flat output.
FxValue = Union[Tensor, int, SymInt, BackwardState]


class CompilerWrapper:
    """
    AOTAutograd needs to do many transformations to the calling convention of the user function
    it is tracing, e.g., deduplicating inputs, unpacking subclasses, etc.  CompilerWrapper lets
    us factor these into compositional stages so we can handle each transformation incrementally
    instead of having to do it all at once.

    Since there is a calling convention change, there are two parts to the wrapper:

    1. The prologue, which is about compile-time behavior: given this original function, what
       is the new function with modified calling convention that we should trace with AOTAutograd
       to get the FX graph we will do joint passes, partitioning and ultimate Inductor compilation on?
       We get (flat_fn, flat_args), the original function under trace and inputs we were
       going to feed it, and produce a new function and new inputs to feed it.

    2. The epilogue, which is about run-time behavior: we have now compiled the modified calling
       convention function, we need to wrap it so that we have a new function that has the
       original calling convention of the original function, so that our users can call it
       at the old signature they expected.  We get (compiled_fn, real arguments), the newly
       compiled function we need to wrap.

    Note about caching: we do NOT directly serialize the runtime wrappers; instead, they
    are reapplied to compiled_fn after we have finished deserializing the compiled_fn.

    Extra metadata that is needed to compute pre or post compile can be passed in via attributes.
    """

    def pre_compile(
        self,
        flat_fn,
        flat_args: list[FxValue],
        flat_args_descs: list[AOTInput],
        aot_config: AOTConfig,
        *,
        fw_metadata: ViewAndMutationMeta,
    ) -> tuple[Callable, list[FxValue], list[AOTInput], ViewAndMutationMeta]:
        """
        Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs.

        This base implementation is the identity: it returns its inputs unchanged.

        Args:
        flat_fn: The function to compile
        flat_args: Metadata from example inputs of the function to compile
        flat_args_descs: The descriptor for each entry of flat_args
        aot_config: AOTConfig passed in at compile time
        fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args

        Returns:
        The tuple (flat_fn, flat_args, flat_args_descs, fw_metadata).
        """
        return flat_fn, flat_args, flat_args_descs, fw_metadata

    def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable:
        """
        Given an output of the compiler, wrap it with information received from prologue.

        This base implementation returns compiled_fn unchanged.

        Args:
        compiled_fn: Callable after calling compiler_fn
        aot_config: AOTConfig after calling prologue
        runtime_metadata: ViewAndMutationMeta after calling all wrappers' pre_compile steps.
        Example:

        def wrapped_compiled_fn(args):
            # do something with args, aot_config, fw_metadata
            return compiled_fn(args)

        return wrapped_compiled_fn
        """
        return compiled_fn


class InductorWrapper:
    """
    This is sort of like CompilerWrapper, but it happens at a different part of the lifecycle:
    it talks about transformations we do to the traced and partitioned FX graph before we
    send it to the Inductor compiler.

    Once again, there are two parts:

    1. The prologue, which "modifies" the FX graph before we send it to
       Inductor.  I say "modifies" because... we don't really actually do
       anything nontrivial in either of our two implementations.
    2. The epilogue, which modifies the compiled function produced by Inductor

    Although hypothetically these wrappers could be used compositionally in a centralized
    wrappers list, in practice they seem to just be invoked manually when needed.

    NB: The flat_args input is sometimes mutated.  This is probably naughty but whatever.
    """

    def pre_compile(
        self,
        fw_module: torch.fx.GraphModule,
        flat_args: list[Tensor],
        aot_config: AOTConfig,
        *,
        fw_metadata: ViewAndMutationMeta,
    ) -> None:
        """
        Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs.

        This base implementation does nothing and returns None.

        Args:
        fw_module: The FX GraphModule about to be handed to the compiler
        flat_args: Metadata from example inputs of the function to compile
        aot_config: AOTConfig passed in at compile time
        fw_metadata: ViewAndMutationMeta generated for the function being compiled
        """
        return

    def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable:
        """
        Given an output of the compiler, wrap it with information received from prologue.

        This base implementation returns compiled_fn unchanged.

        Args:
        compiled_fn: Callable after calling compiler_fn
        aot_config: AOTConfig after calling prologue
        runtime_metadata: ViewAndMutationMeta after calling all wrappers' pre_compile steps.
        Example:

        def wrapped_compiled_fn(args):
            # do something with args, aot_config, fw_metadata
            return compiled_fn(args)

        return wrapped_compiled_fn
        """
        return compiled_fn


@dataclass
class AOTGraphCapture:  # Produced by aot_stage1_graph_capture
    # AOTAutograd typically operates by taking complicated graphs and
    # desugaring them into simpler graphs that use PyTorch features.  These
    # wrappers establish invariants so that when we actually do tracing we can
    # assume these invariants hold, leading to a simpler tracing
    # implementation.  However, this means that we have to keep track of how
    # to enter/exit these wrappers when passing inputs into the compiled
    # graph, among other things!
    wrappers: list[CompilerWrapper]

    # The actual captured graph module.  In some circumstances (export) this
    # graph has a specific calling convention that can be relied upon by
    # external callers.  In other situations, the calling convention is
    # unspecified and only aot_stage2_compile knows how to deal with them.
    graph_module: torch.fx.GraphModule

    # When compiling with autograd support, this is the joint_inputs, which is
    # larger than the original flat_args as all tangents get inputs.  The
    # tuple organizes into primals and tangents.  When not autograd it's just
    # a plain list.
    updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]]

    # Descriptors for updated_flat_args; the annotation mirrors its
    # list-or-(primals, tangents) structure.
    updated_flat_args_descs: Union[
        list[AOTInput], tuple[list[AOTInput], list[AOTInput]]
    ]

    # Metadata about subclass inputs/outputs in the graph trace.
    maybe_subclass_meta: Any


# NewType marker over list[Any]; the name suggests the arguments have been
# converted to fake tensors -- presumably enforced by the producer, not here.
FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any])


# Type variable bounded by OutputCode (forward reference, resolved lazily).
TOutputCode = TypeVar("TOutputCode", bound="OutputCode")


class AOTDispatchCompiler(Protocol):
    """
    Represents a fw or bw_compiler passed to AOTAutograd.
    """

    def __call__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[InputType],
    ) -> Any: ...
# TODO: bikeshed on this name
class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
    """
    Represents an AOTDispatchCompiler that returns an OutputCode, and is
    therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode.
    A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of
    the kwargs in _CompileFxKwargs.
    """

    def __init__(
        self,
        output_code_ty: type[TOutputCode],
        compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
    ):
        """Store the OutputCode type and the compiler callable that produces it."""
        # pyrefly: ignore [invalid-type-var]
        self.output_code_ty = output_code_ty
        # pyrefly: ignore [invalid-type-var]
        self.compiler_fn = compiler_fn

    def __call__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[InputType],
    ) -> OutputCode:
        """Delegate to the stored compiler_fn and return its result."""
        return self.compiler_fn(gm, example_inputs)


class FlatFn(Protocol):
    """Callable taking flat FxValue positional args and returning a list of FxValue."""

    def __call__(self, *args: FxValue) -> list[FxValue]: ...


class TraceFn(Protocol):
    """Like FlatFn, but also returns one AOTOutput descriptor list alongside the outputs."""

    def __call__(self, *args: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: ...


class PreppedForAutogradTraceFn(Protocol):
    """
    Callable returning ((outputs, per-output bool flags), output descriptors).
    The meaning of the bool flags is defined by the producer of this callable.
    """

    def __call__(
        self,
        *args: FxValue,
    ) -> tuple[tuple[list[FxValue], list[bool]], list[AOTOutput]]: ...


class JointTraceFn(Protocol):
    """
    Callable taking (primals, tangents) and returning a pair of pairs: the
    values, then their descriptors. The second element of each pair allows
    None entries.
    """

    handle: JointFnHandle

    def __call__(
        self, primals: list[FxValue], tangents: list[FxValue]
    ) -> tuple[
        tuple[list[FxValue], list[Optional[Tensor]]],
        tuple[list[AOTOutput], list[Optional[AOTOutput]]],
    ]: ...


@dataclass
class JointWithDescriptors:
    """Bundles an AOTState and AOTGraphCapture with the param/buffer order and pytree specs."""

    _aot_state: AOTState
    _aot_graph_capture: AOTGraphCapture

    # The exact order parameters and buffers are expected to be passed into
    # the final compiled function.  Parameters before buffers.
    params_spec: list[str]
    buffers_spec: list[str]

    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec

    @property
    def graph_module(self):
        """The graph module held by the wrapped AOTGraphCapture."""
        return self._aot_graph_capture.graph_module

    @graph_module.setter
    def graph_module(self, value):
        # Writes through to the wrapped AOTGraphCapture.
        self._aot_graph_capture.graph_module = value
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/streams.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/streams.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eb76a637bf71ca8b813d68fcae3123159a21114
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/streams.py
@@ -0,0 +1,281 @@
from typing import Any, Optional, TypeAlias

import torch.fx
import torch.fx.traceback
import torch.utils._pytree as pytree
from torch._dynamo.graph_utils import _get_flat_args
from torch._dynamo.variables.streams import get_current_stream, new_event
from torch.utils._runtime_estimation import (
    _FLOAT_TYPES,
    _IGNORE_OPS,
    get_compute_time,
    get_transfer_time,
)

from .indexed_dict import IndexedDict


Node: TypeAlias = torch.fx.Node
Graph: TypeAlias = torch.fx.Graph


def get_roofline_estimate(node: Node) -> float:
    """
    Estimate the runtime of a call_function node.

    Returns 0.0 for targets in _IGNORE_OPS; otherwise the larger of the
    transfer-time and compute-time estimates, divided by 1e6.

    Raises:
        AssertionError: If node.op is not "call_function".
    """
    assert node.op == "call_function", "non-func node in roofline estimate"

    def map_value(x: Any) -> Any:
        # Replace a Node by its meta["value"] when present; pass anything
        # else (and Nodes without that key) through unchanged.
        # NOTE(review): this reads meta["value"], whereas get_device below
        # reads meta["val"] -- confirm which key is populated on these graphs.
        return x.meta.get("value", x) if isinstance(x, Node) else x

    func = node.target
    if func in _IGNORE_OPS:
        return 0.0

    mapped_args = torch.fx.map_arg(node.args, map_value)
    mapped_kwargs = torch.fx.map_arg(node.kwargs, map_value)
    flat_args_kwargs = [map_value(x) for x in _get_flat_args(node, {})]
    flat_outs, _ = pytree.tree_flatten(node.meta.get("value", node))
    out = node.meta.get("value", node)
    # Only floating dtypes (as listed in _FLOAT_TYPES) of tensor outputs.
    out_dtypes = {
        t.dtype
        for t in flat_outs
        if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES
    }

    # presumably a unit conversion of the helper results -- TODO confirm the
    # units returned by get_transfer_time / get_compute_time.
    return (
        max(
            get_transfer_time(flat_args_kwargs, flat_outs),
            get_compute_time(func, mapped_args, mapped_kwargs, out, out_dtypes),
        )
        / 1e6
    )


def is_gradient_acc(node: Node) -> bool:
    """Return node.meta["is_gradient_acc"], defaulting to False when absent."""
    return node.meta.get("is_gradient_acc", False)


def is_bwd_node(node: Node) -> bool:
    """Return True when the node's partitioner_tag is "is_backward" or "must_be_in_backward"."""
    tag = node.meta.get("partitioner_tag")
    return tag == "is_backward" or tag == "must_be_in_backward"


def get_device(node: Node) -> torch.device:
    """Return the device of node.meta["val"]; raises KeyError if "val" is missing."""
    return node.meta["val"].device


def get_stream(node: Node) -> Optional[int]:
    """Return node.meta["custom"]["stream"], or None if either key is missing."""
    maybe_annotation = node.meta.get("custom", None)
    if maybe_annotation is not None:
        return node.meta["custom"].get("stream", None)
    else:
        return None


def get_stream_or_current_stream(node: Node) -> int:
    """Return the node's annotated stream, falling back to the current stream of its device."""
    ind = get_stream(node)
    if ind is None:
        ind = get_current_stream(get_device(node))
    return ind


def set_stream(node: Node, ind: int) -> None:
    """Store ind under node.meta["custom"]["stream"], creating the "custom" dict if needed."""
    if "custom" in node.meta:
        node.meta["custom"].update({"stream": ind})
    else:
        node.meta["custom"] = {"stream": ind}


def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> Node:
    """
    Insert a streams.record_event call right after `node` and return the new node.

    The call's arguments are (event_ind, stream of `node`); the new node is
    tagged "must_be_in_backward".
    """
    with graph.inserting_after(node):
        # The argument tuple is evaluated before `node` is rebound, so the
        # stream lookup uses the original node.
        node = graph.call_function(
            torch.ops.streams.record_event.default,
            (
                event_ind,
                get_stream_or_current_stream(node),
            ),
        )
        node.meta["partitioner_tag"] = "must_be_in_backward"

    return node


def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> Node:
    """
    Insert a streams.wait_event call right before `node` and return the new node.

    The call's arguments are (event_ind, stream of `node`); the new node is
    tagged "must_be_in_backward".
    """
    with graph.inserting_before(node):
        # As above: the stream lookup uses the original node.
        node = graph.call_function(
            torch.ops.streams.wait_event.default,
            (
                event_ind,
                get_stream_or_current_stream(node),
            ),
        )
        node.meta["partitioner_tag"] = "must_be_in_backward"

    return node


def populate_stream_timeline(
    stream_to_timeline: dict[Optional[int], IndexedDict[Node, float]],
    graph: Graph,
    stream_index: Optional[int],
) -> IndexedDict[Node, float]:
    """
    Return the timeline for stream_index, building and caching it on first use.

    The timeline maps each backward node whose annotated stream equals
    stream_index to the running sum of roofline estimates up to and including
    that node, in graph order. An existing entry in stream_to_timeline is
    returned as-is without being recomputed.
    """
    if stream_index not in stream_to_timeline:
        stream_to_timeline[stream_index] = IndexedDict()
        total_time = 0.0
        for node in graph.nodes:
            # mlazos: not sure if we should include forward here too but don't think it matters
            if is_bwd_node(node) and get_stream(node) == stream_index:
                total_time += get_roofline_estimate(node)
                stream_to_timeline[stream_index][node] = (
                    total_time  # NB: total time includes the node's runtime
                )

    return stream_to_timeline[stream_index]


# NB: we start all estimates at 0, estimating the total runtime of each stream with timestamps at each node
# we then try and use these timestamps to estimate when to deallocate tensors used in side streams
# See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream
# for details on the problem being addressed. Rather than using the automatic memory management approach of record_stream
# we attempt to find the point which to deallocate based on the estimated timestamps.
def handle_synced_deallocation(
    graph: Graph,
    stream_to_exec_trace: dict[Optional[int], IndexedDict[Node, float]],
    node: Node,
    last_usage: Node,
) -> None:
    """
    Insert a record_event / sync_dealloc pair for a tensor whose last use is
    on a different stream than the one it was produced on.

    Args:
        graph: Graph to insert into.
        stream_to_exec_trace: Per-stream timeline cache, filled on demand.
        node: The backward node producing the tensor.
        last_usage: The backward node that is the tensor's last user.

    Raises:
        AssertionError: If either node is not a backward node, or both have
            the same annotated stream.
    """
    assert is_bwd_node(node), (
        "synced allocations should only be handled on backward nodes"
    )
    assert is_bwd_node(last_usage), (
        "synced allocations should only be handled on backward nodes"
    )
    allocating_stream = get_stream(node)
    side_stream = get_stream(last_usage)
    assert allocating_stream != side_stream, (
        "allocating and side stream should be different for synced deallocations"
    )
    if not torch.cuda.is_available():
        # fallback to record_stream in this case
        # NOTE(review): there is no early return after this fallback, so the
        # timeline search and sync_dealloc insertion below still run when CUDA
        # is unavailable -- confirm that is intended.
        with graph.inserting_after(node):
            graph.call_function(
                torch.ops.streams.record_stream.default,
                (
                    node,
                    get_stream_or_current_stream(last_usage),
                ),
                {},
            )
            # NOTE(review): this tags `node` (the allocating node), not the
            # call_function node just created, whose return value is
            # discarded -- confirm.
            node.meta["partitioner_tag"] = "must_be_in_backward"

    allocating_stream_trace = populate_stream_timeline(
        stream_to_exec_trace, graph, allocating_stream
    )
    side_stream_trace = populate_stream_timeline(
        stream_to_exec_trace, graph, side_stream
    )

    alloc_ptr = node
    target_side_stream_time = side_stream_trace[last_usage]
    # linear search from first usage of tensor to a point in time after the side stream has finished
    # If the allocating stream's timeline ends first, alloc_ptr stays on its
    # last entry.
    while alloc_ptr is not None:
        alloc_time = allocating_stream_trace[alloc_ptr]

        if alloc_time >= target_side_stream_time:
            break
        elif alloc_time < target_side_stream_time:
            next_ptr = allocating_stream_trace.next_key(alloc_ptr)
            if next_ptr is not None:
                alloc_ptr = next_ptr
            else:
                break

    wait_event = new_event()
    record_node = insert_record_event_after_node(graph, last_usage, wait_event)
    # Insert after whichever of the two nodes compares greater; presumably
    # fx Node ordering follows graph position -- TODO confirm.
    with graph.inserting_after(max(alloc_ptr, record_node)):
        graph.call_function(
            torch.ops.streams.sync_dealloc.default,
            (wait_event, get_stream_or_current_stream(alloc_ptr), node),
            {},
        )
        # NOTE(review): as in the fallback above, the tag is set on `node`,
        # not on the newly inserted sync_dealloc node -- confirm.
        node.meta["partitioner_tag"] = "must_be_in_backward"


def insert_sync(
    graph: Graph,
    consumer: Node,
    producer: Node,
    node_to_wait_event_ind: dict[Node, int],
) -> None:
    """
    Make `consumer` wait on an event recorded after `producer`.

    The event is created and recorded only the first time a producer is seen
    (tracked in node_to_wait_event_ind); a wait is inserted before the
    consumer on every call.
    """
    if producer not in node_to_wait_event_ind:
        node_to_wait_event_ind[producer] = new_event()

        insert_record_event_after_node(
            graph, producer, node_to_wait_event_ind[producer]
        )
    insert_wait_event_before_node(graph, consumer, node_to_wait_event_ind[producer])


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""

    # NB: iterate in reverse order to more closely match eager
    # the user node stream will be populated first
    for node in reversed(list(gm.graph.nodes)):
        if is_gradient_acc(node):
            # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
            # 1. Match first stream assignment of the first user with a stream
            # 2. Match first stream assignment encountered in the args from left to right
            # This differs from eager in some cases:
            # Specifically the eager code uses the autograd node to determine the stream,
            # crucially this does not necessarily correspond to the FX graph node. For example,
            # in the backward for an add node with a constant we will passthrough and during backward tracing,
            # no op will be added to the FX graph, so our stream assignment will differ in this case.
            gradients = _get_flat_args(node, {})
            users = list(node.users.keys())

            # All gradients will be on same device, they will be coerced if they were not with a .to() node
            # Users are checked before args; the first neighbor with a stream
            # wins. A node with no such neighbor is left unannotated.
            for neighbor in users + gradients:
                ind = get_stream(neighbor)
                if ind is not None:
                    set_stream(node, ind)
                    break


def insert_backward_syncs(gm: torch.fx.GraphModule) -> None:
    """Inserts stream syncs for backward nodes if consumer and producer are on different streams"""
    # Shared across the whole pass so each producer gets a single event.
    node_to_wait_event_ind: dict[Node, int] = {}
    for node in gm.graph.nodes:
        if is_bwd_node(node):
            flat_args = _get_flat_args(node, {})
            cur_node_stream = get_stream(node)

            for arg in flat_args:
                # assumes every entry of flat_args is a Node (it is passed to
                # helpers that read .meta) -- TODO confirm against
                # _get_flat_args.
                if is_bwd_node(arg):
                    arg_stream = get_stream(arg)
                    # Producers on the CPU device are never synced.
                    if arg_stream != cur_node_stream and get_device(arg).type != "cpu":
                        insert_sync(gm.graph, node, arg, node_to_wait_event_ind)


def sync_deallocations(gm: torch.fx.GraphModule) -> None:
    """Handles https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream"""
    # Note: this is only needed if the last usage of a tensor is on a stream other than
    # the stream the tensor was allocated on

    # an estimated timestamp from the beginning of graph execution (assuming 0 CPU overhead)
    # I think this is fine because you should have large tensors if you're using streams
    # although perhaps I could add a constant 10us per op ahead of the first stream op?
    # a trace of all the nodes running in a given stream
    stream_to_exec_trace: dict[Optional[int], IndexedDict[Node, float]] = {}
    for node in gm.graph.nodes:
        if is_bwd_node(node):
            allocating_stream = get_stream(node)
            users = list(node.users.keys())
            if not users:
                continue
            # The greatest user under Node comparison; presumably the one
            # latest in graph order -- TODO confirm.
            last_user = max(user for user in users)
            # Tensors whose last user is the graph output are skipped.
            if last_user.op == "output":
                continue
            side_stream = get_stream(last_user)
            if allocating_stream != side_stream:
                # NOTE(review): handle_synced_deallocation asserts that
                # last_user is a backward node; that is not checked here --
                # confirm every user of a backward node carries a backward tag.
                handle_synced_deallocation(
                    gm.graph, stream_to_exec_trace, node, last_user
                )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ea6635a62e81a57fba45e97d5f0eb2109e48d8f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py
@@ -0,0 +1,104 @@
import dataclasses
import itertools
from collections.abc import Iterable
from typing import Any, Union

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


# This is technically very similar to SubclassCreatingMeta
# in aot_autograd, but we don't need all the stuff in there
# so just recreated a new dataclass.
@dataclasses.dataclass
class SubclassCreationMeta:
    # Index of this subclass's first plain tensor in the flattened list.
    start_idx: int
    # Number of plain tensors this subclass (including nested ones) spans.
    num_tensors: int
    # The subclass type; its __tensor_unflatten__ is used to rebuild it.
    class_type: Any
    # Per inner-tensor attribute: nested meta, or None for a plain tensor.
    attrs: dict[str, "SubclassCreationMeta"]
    # Second value returned by the subclass's __tensor_flatten__().
    metadata: Any
    outer_size: Iterable[Union[None, int, torch.SymInt]]
    outer_stride: Iterable[Union[None, int, torch.SymInt]]


class UnwrapTensorSubclass(torch.nn.Module):
    """
    Parametrization module: right_inverse flattens a tensor subclass into its
    plain tensors and records how; forward rebuilds the subclass from them.

    forward reads self.subclass_meta, which is only set by right_inverse.
    """

    def forward(self, *tensors) -> torch.Tensor:  # type: ignore[no-untyped-def]
        """Rebuild the subclass instance from the flat plain tensors."""
        todo: list[torch.Tensor] = list(tensors)

        def _unwrap_tensor_subclasses(subclass_meta, tensors, offset):  # type: ignore[no-untyped-def]
            # Returns (rebuilt value, next unread offset into `tensors`).
            # A None meta marks a plain tensor occupying exactly one slot.
            if subclass_meta is None:
                return tensors[offset], offset + 1
            inner_tensors = {}
            for attr, meta in subclass_meta.attrs.items():
                built_tensor, offset = _unwrap_tensor_subclasses(meta, tensors, offset)
                inner_tensors[attr] = built_tensor
            rebuilt = subclass_meta.class_type.__tensor_unflatten__(
                inner_tensors,
                subclass_meta.metadata,
                subclass_meta.outer_size,
                subclass_meta.outer_stride,
            )
            return rebuilt, offset

        return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0]

    def right_inverse(self, tensor: torch.Tensor) -> list[torch.Tensor]:
        """
        Flatten `tensor` into its plain tensors, depth-first in
        __tensor_flatten__ attribute order, and store the matching
        SubclassCreationMeta tree on self.subclass_meta.

        Raises:
            AssertionError: If `tensor` is exactly a torch.Tensor.
        """
        assert type(tensor) is not torch.Tensor
        plain_tensors: list[torch.Tensor] = []

        def _create_subclass_meta(tensor, idx, plain_tensor_container):  # type: ignore[no-untyped-def]
            # Returns (meta or None, next free index). Plain tensors are
            # appended to the container and described by a None meta.
            if type(tensor) is torch.Tensor:
                plain_tensor_container.append(tensor)
                return None, idx + 1
            inner_tensors_attrnames, metadata = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            new_idx = idx
            attr_to_meta = {}
            for attr in inner_tensors_attrnames:
                val = getattr(tensor, attr)
                subclass_meta, new_idx = _create_subclass_meta(
                    val, new_idx, plain_tensor_container
                )
                attr_to_meta[attr] = subclass_meta
            return (
                SubclassCreationMeta(
                    start_idx=idx,
                    num_tensors=new_idx - idx,
                    class_type=type(tensor),
                    attrs=attr_to_meta,
                    metadata=metadata,
                    outer_size=tensor.size(),
                    outer_stride=tensor.stride(),
                ),
                new_idx,
            )

        self.subclass_meta = _create_subclass_meta(tensor, 0, plain_tensors)[0]
        return plain_tensors


def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Module:
    """
    Model transformation that replaces all the parameters that are subclasses to plain tensors.
    This reduces runtime overhead of flattening/unflattening the parameters.

    This transformation adds parametrization with `torch.nn.utils.parametrize`.
    The FQNs of the subclass parameters will be changed and state_dict will become incompatible with the original model.
    E.g.
    Original model state_dict: {"p1": torch.testing._internal.TwoTensor}
    becomes: {"parametrizations.p1.original0": torch.Tensor, "parametrizations.p1.original1": torch.Tensor}

    Buffers that are traceable wrapper subclasses are handled the same way.
    The module is modified in place, recursively over its children, and
    returned.
    """
    # The names are materialized into lists before looping, so registering a
    # parametrization does not disturb the iteration.
    for name, tensor in itertools.chain(
        list(module.named_parameters(recurse=False)),
        # pyrefly: ignore [no-matching-overload]
        list(module.named_buffers(recurse=False)),
    ):
        if is_traceable_wrapper_subclass(tensor):
            torch.nn.utils.parametrize.register_parametrization(
                module, name, UnwrapTensorSubclass()
            )

    for child in module.children():
        unwrap_tensor_subclass_parameters(child)

    return module
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a579888dfade33b49ba6f24d1542bcc24a082f29
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py
@@ -0,0 +1,520 @@
# mypy: allow-untyped-defs
"""
This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes.
AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher,
and this includes tensor subclasses that implement __torch_dispatch__.
"""

import collections
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, TypeGuard, TypeVar, Union

import torch
import torch.utils._pytree as pytree
from torch import SymInt, Tensor
from torch._subclasses.fake_tensor import get_plain_tensors
from torch.types import IntLikeType
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .descriptors import (
    AOTInput,
    AOTOutput,
    DummyAOTInput,
    SubclassGetAttrAOTInput,
    SubclassGetAttrAOTOutput,
    SubclassSizeAOTInput,
    SubclassSizeAOTOutput,
    SubclassStrideAOTInput,
    SubclassStrideAOTOutput,
)
from .schemas import (
    FxValue,
    MutationType,
    PlainTensorMeta,
    SubclassCreationMeta,
    ViewAndMutationMeta,
)
from .utils import strict_zip


# Shadows the builtin zip with strict_zip for the rest of this module.
zip = strict_zip

T = TypeVar("T", bound=torch.Tensor)


def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
    """
    Return True if any tensor leaf of `args` is a traceable wrapper subclass,
    or any entry of fw_metadata.subclass_fw_graph_out_meta is exactly a
    SubclassCreationMeta.
    """
    args_flattened = pytree.arg_tree_leaves(*args)
    any_subclass_args = any(
        is_traceable_wrapper_subclass(x)
        for x in args_flattened
        if isinstance(x, Tensor)
    )
    from torch._functorch._aot_autograd.schemas import SubclassCreationMeta

    any_subclass_outputs = any(
        type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta
    )
    # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime.
    return any_subclass_args or any_subclass_outputs


from .schemas import MemoryFormatMeta


def maybe_suggest_memory_format(
    t, with_memory_format: bool
) -> Optional[MemoryFormatMeta]:
    """Return MemoryFormatMeta.from_tensor(t), or None when with_memory_format is false."""
    if not with_memory_format:
        return None

    return MemoryFormatMeta.from_tensor(t)


def get_subclass_typing_container(
    tensor_subclass: torch.Tensor,
) -> dict[type[torch.Tensor], list[type[torch.Tensor]]]:
    """
    Given a subclass, recursively walks its inner tensors and returns a
    dictionary mapping each traceable wrapper subclass type encountered to
    the list of tensors of that type (outermost first). Plain tensors are
    not recorded.

    NOTE(review): the return annotation says the lists hold types, but the
    body appends the tensor instances themselves -- confirm which is intended.
    """

    def _get_types_for_subclass(tensor_subclass: torch.Tensor) -> None:
        if not is_traceable_wrapper_subclass(tensor_subclass):
            return
        tracker[type(tensor_subclass)].append(tensor_subclass)
        inner_keys, _ = tensor_subclass.__tensor_flatten__()
        for key in inner_keys:
            inner_tensor = getattr(tensor_subclass, key)
            _get_types_for_subclass(inner_tensor)

    tracker: dict[Any, list[Any]] = collections.defaultdict(list)
    _get_types_for_subclass(tensor_subclass)
    return tracker


def create_subclass_metadata(
    a: Any, start_idx: int, count_symints: bool, with_memory_format: bool = False
):
    """
    Recursively describe `a` for later reconstruction from flat tensors.

    Args:
        a: A plain tensor or a traceable wrapper subclass.
        start_idx: Index of the first flat slot available to `a`.
        count_symints: When true, the non-nested SymInts in the outer size
            and stride of each subclass are counted as extra flat slots.
        with_memory_format: When true, memory-format metadata is recorded.

    Returns:
        (meta, next_idx): a PlainTensorMeta or SubclassCreationMeta, and the
        index of the first flat slot after `a`.
    """
    if not is_traceable_wrapper_subclass(a):
        idx = start_idx + 1
        # NOTE(review): the index stored in PlainTensorMeta here is
        # start_idx + 1, whereas create_subclass_meta stores the unshifted
        # index for top-level plain tensors -- confirm the offset is intended.
        return (
            PlainTensorMeta(
                idx,
                memory_format=maybe_suggest_memory_format(a, with_memory_format),
            ),
            idx,
        )

    inner_keys, metadata = a.__tensor_flatten__()
    new_start_idx = start_idx
    attrs = {}

    # Inner tensors consume flat slots in __tensor_flatten__ key order.
    for key in inner_keys:
        new_subclass_meta, new_start_idx = create_subclass_metadata(
            getattr(a, key),
            new_start_idx,
            count_symints=count_symints,
            with_memory_format=with_memory_format,
        )
        attrs[key] = new_subclass_meta

    # It *must* be because is_traceable_wrapper_subclass() - but mypy is not smart.
    assert isinstance(a, Tensor)

    # Multiplying by the bool adds the symint counts only when requested.
    new_start_idx = (
        new_start_idx
        + count_symints * len(enumerate_filter_symints(a.size()))
        + count_symints * len(enumerate_filter_symints(a.stride()))
    )

    return (
        SubclassCreationMeta(
            flat_tensor_start_idx=start_idx,
            arg_count=new_start_idx - start_idx,
            included_subclass_symints=count_symints,
            attrs=attrs,
            meta=metadata,
            outer_size=a.size(),  # type: ignore[attr-defined, arg-type]
            outer_stride=a.stride(),  # type: ignore[arg-type]
            original_subclass=a,
            memory_format=maybe_suggest_memory_format(a, with_memory_format),
        ),
        new_start_idx,
    )


# Given a flat list of arguments, some of which may be tensor subclasses,
# computes metadata about "how to reconstruct the current list of subclasses,
# if we were given their flattened dense tensors instead"
def create_subclass_meta(
    curr_args: Union[list[Any], tuple[Any, ...]],
    *,
    count_symints: bool = True,
    with_memory_format: bool = False,
) -> list[Union[PlainTensorMeta, SubclassCreationMeta]]:
    """
    Return one meta entry per element of curr_args, in order.

    Subclass arguments get a SubclassCreationMeta and advance the running
    flat index by its arg_count; every other argument gets a PlainTensorMeta
    at the current index and advances it by one.
    """
    idx = 0
    infos: list[Union[PlainTensorMeta, SubclassCreationMeta]] = []
    for a in curr_args:
        if is_traceable_wrapper_subclass(a):
            assert isinstance(a, Tensor)
            start_idx = idx
            subclass_meta, _ = create_subclass_metadata(
                a,
                start_idx,
                count_symints=count_symints,
                with_memory_format=with_memory_format,
            )
            infos.append(subclass_meta)
            cnt = subclass_meta.arg_count
        else:
            infos.append(
                PlainTensorMeta(
                    idx,
                    memory_format=maybe_suggest_memory_format(a, with_memory_format),
                )
            )
            cnt = 1
        idx += cnt
    return infos


def enumerate_filter_symints(lst: Iterable[IntLikeType]) -> list[tuple[int, SymInt]]:
    """Return (position, value) pairs for the SymInts in lst that are not nested ints."""

    # Capture all SymInts from the iterable.
    def symint_check(s: IntLikeType) -> TypeGuard[SymInt]:
        return isinstance(s, SymInt) and not s.node.is_nested_int()

    return [(i, s) for i, s in enumerate(lst) if symint_check(s)]


def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list[bool]:
    """Return a list of booleans, True at each position where lst holds None."""
    # Non-nested symints are replaced with None in `make_runtime_safe()`
    return [s is None for s in lst]


# Intended to make it easier to define function that is
# either (AOTInput -> AOTInput) or (AOTOutput -> AOTOutput)
# but not the other combos
AOTDescriptor = TypeVar("AOTDescriptor", AOTInput, AOTOutput)


# This function takes in a pytree of arguments and unwraps any tensor
# subclasses.
#
# NOTE: The reason for "append_symints":
#
# * At compile time: we append extra symint args when unwrapping primals
# (but not tangents, because they should always share symints with primals).
# We also append extra symints when unwrapping the subclass outputs of the
# traced function, so we can return them as extra outputs
#
# * At runtime: we similarly append subclass sizes when we unwrap subclass
# primals (but not tangents) on entry to the forward. See the runtime version of
# this function below.
+def unwrap_tensor_subclasses( + wrapped_args: list[FxValue], + wrapped_args_descs: list[AOTDescriptor], + *, + append_symints: bool, +) -> tuple[list[FxValue], list[AOTDescriptor]]: + def flatten_subclass( + t: FxValue, + desc: AOTDescriptor, + *, + out: tuple[list[FxValue], list[AOTDescriptor]], + ): + # unwrap a subclass into plain tensors and their size/stride if "append_symint" + # is True + if not is_traceable_wrapper_subclass(t): + out[0].append(t) + out[1].append(desc) + return + + attrs, _ = t.__tensor_flatten__() + + for attr in attrs: + inner_tensor = getattr(t, attr) + n_desc: Any = ( + SubclassGetAttrAOTInput(desc, attr) + if isinstance(desc, AOTInput) + # pyrefly: ignore [bad-argument-type] + else SubclassGetAttrAOTOutput(desc, attr) + ) + flatten_subclass(inner_tensor, n_desc, out=out) + + if append_symints: + sizes = enumerate_filter_symints(t.size()) + strides = enumerate_filter_symints(t.stride()) + out[0].extend(s for _, s in sizes) + out[0].extend(s for _, s in strides) + if isinstance(desc, AOTInput): + out[1].extend(SubclassSizeAOTInput(desc, i) for i, _ in sizes) # type: ignore[misc] + out[1].extend(SubclassStrideAOTInput(desc, i) for i, _ in strides) # type: ignore[misc] + else: + out[1].extend(SubclassSizeAOTOutput(desc, i) for i, _ in sizes) # type: ignore[misc] + out[1].extend(SubclassStrideAOTOutput(desc, i) for i, _ in strides) # type: ignore[misc] + + xs_inner: list[FxValue] = [] + descs_inner: list[AOTDescriptor] = [] + + for x, desc in zip(wrapped_args, wrapped_args_descs): + # pyrefly: ignore [bad-argument-type] + flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner)) + + return xs_inner, descs_inner + + +# subclass_metas is needed at runtime to compute which indices are symints in +# the outer_size/outer_stride +def runtime_unwrap_tensor_subclasses( + wrapped_args: list[Union[Tensor, int]], + *, + append_symints: bool, + subclass_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = None, +): + 
def flatten_subclass(x: Tensor, meta: Optional[SubclassCreationMeta], *, out): + if not is_traceable_wrapper_subclass(x): + out.append(x) + return out + + assert isinstance(x, Tensor) + + attrs, _ = x.__tensor_flatten__() + + for attr in attrs: + inner_tensor = getattr(x, attr) + # pyrefly: ignore [missing-attribute] + inner_meta = meta.attrs.get(attr) + flatten_subclass(inner_tensor, inner_meta, out=out) + + if append_symints: + assert isinstance(meta, SubclassCreationMeta) + # outer_size + size = x.size() + symint_placeholders = compute_symint_placeholders(meta.outer_size) + assert len(size) == len(symint_placeholders) + out.extend( + [r for (r, is_symint) in zip(size, symint_placeholders) if is_symint] + ) + + # outer_stride + stride = x.stride() + symint_placeholders = compute_symint_placeholders(meta.outer_stride) + assert len(stride) == len(symint_placeholders) + out.extend( + [r for (r, is_symint) in zip(stride, symint_placeholders) if is_symint] + ) + return out + + xs_inner: list[Union[int, Tensor, SymInt]] = [] + + if append_symints: + assert subclass_metas is not None + + for idx, x in enumerate(wrapped_args): + if not is_traceable_wrapper_subclass(x): + xs_inner.append(x) + continue + + if subclass_metas is None: + get_plain_tensors(typing.cast(Tensor, x), out=xs_inner) + else: + meta = subclass_metas[idx] + assert isinstance(meta, SubclassCreationMeta) + flatten_subclass(typing.cast(Tensor, x), meta, out=xs_inner) + + return xs_inner + + +def unwrap_tensor_subclasses_with_indices_to_original(wrapped_args): + ret_unwrapped = [] + ret_indices_to_original = [] + for i, a in enumerate(wrapped_args): + a_unwrapped, _ = unwrap_tensor_subclasses( + [a], [DummyAOTInput(9999)], append_symints=False + ) + ret_unwrapped.extend(a_unwrapped) + n = len(a_unwrapped) + ret_indices_to_original.extend([i] * n) + + return ret_unwrapped, ret_indices_to_original + + +def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices): + static_input_indices = 
set(static_input_indices) + new_ind = 0 + remapped_static_indices = [] + for i, arg in enumerate(wrapped_args): + num_indices = 1 + if is_traceable_wrapper_subclass(arg): + num_indices = ( + len(get_plain_tensors(typing.cast(Tensor, arg), out=[])) + + len(enumerate_filter_symints(arg.size())) + + len(enumerate_filter_symints(arg.stride())) + ) + + for _ in range(num_indices): + if i in static_input_indices: + remapped_static_indices.append(new_ind) + + new_ind += 1 + + return remapped_static_indices + + +# Turns a flattened list of tensor arguments into (maybe) subclass tensors. +# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in. +def wrap_tensor_subclasses( + unwrapped_args: Union[tuple[Any, ...], list[Any]], + *, + subclass_metas: list[Union[PlainTensorMeta, SubclassCreationMeta]], + num_fw_outs_saved_for_bw: Optional[int] = None, + included_subclass_symints: bool = False, + is_runtime: bool = False, + make_subclass_override: Optional[Callable] = None, +) -> tuple[Any, ...]: + wrapped_args = [] + num_args_tallied = 0 + for subclass_meta in subclass_metas: + if isinstance(subclass_meta, PlainTensorMeta): + wrapped_args.append(unwrapped_args[subclass_meta.unwrapped_idx]) + num_args_tallied += 1 + else: + assert isinstance(subclass_meta, SubclassCreationMeta) + assert subclass_meta.included_subclass_symints == included_subclass_symints + + if make_subclass_override: + wrapped_args.append( + make_subclass_override(subclass_meta, is_runtime, unwrapped_args) + ) + else: + wrapped_args.append( + subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + ) + num_args_tallied += subclass_meta.arg_count + + # Note: [Partitioner handling for Subclasses, Part 2] + # At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw, + # to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them. 
+ # + # When this function is called at runtime in the forward, + # we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs. + # + # One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen? + # Answer: we do it **inside of our compiled autograd.Function**. + # This seems like morally the right place: autograd happens above subclass desugaring, + # so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors. + # + # This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph + # into a forward and backward graph, we end up with some activations that show up as extra outputs + # in the compiled forward graph, that are **not** user outputs. + # These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses. + # + # On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`), + # we computed subclass metadata on every forward output, but this did **not** include activations + # created by the partitioner. + # as a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations), + # but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`. + # We then need to make sure that we return (*wrapped_user_fw_outs, *activations). 
+ if num_fw_outs_saved_for_bw is not None: + assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, ( + f"Expected the number actual unwrapped-subclass outputs {len(unwrapped_args)} to equal " + f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of " + f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})" + ) + activations = unwrapped_args[num_args_tallied:] + if isinstance(wrapped_args, tuple) and isinstance(activations, tuple): + return wrapped_args + activations + return tuple(list(wrapped_args) + list(activations)) + else: + assert len(unwrapped_args) == num_args_tallied, ( + f"Expected {len(unwrapped_args)} == {num_args_tallied}" + ) + return tuple(wrapped_args) + + +# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses. +# This function carefully handles the inference vs. joint cases: +# - when is_joint_structure is True, args is (primals, tangents) +# - when is_joint_structure is False, args is [*primals] +def wrap_tensor_subclasses_maybe_joint( + unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta +) -> Union[tuple[Any, ...], list[Any]]: + # Since this function is reused for both inference and joint graphs, + if is_joint_structure: + assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2 + assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance( + unwrapped_args[1], (tuple, list) + ) + primals, tangents = unwrapped_args[0], unwrapped_args[1] + wrapped_primals = wrap_tensor_subclasses( + primals, + subclass_metas=meta.subclass_inp_meta, + included_subclass_symints=True, + ) + wrapped_tangents = wrap_tensor_subclasses( + tangents, + subclass_metas=meta.subclass_tangent_meta, + included_subclass_symints=False, + ) + return (wrapped_primals, wrapped_tangents) + else: + wrapped_args = wrap_tensor_subclasses( + unwrapped_args, + subclass_metas=meta.subclass_inp_meta, + 
included_subclass_symints=True, + ) + return wrapped_args + + +def compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata: ViewAndMutationMeta, + inner_metadata: ViewAndMutationMeta, +) -> list[int]: + # Note: [Recomputing subclass mutation handling] + # + # Generally, if a subclass requires grad, its components will not require grad. + # But for the purposes of tracking returned tensors, we should treat those component + # tensors as if they require grad. + # + # For example, if the subclass tensor requires grad and will be mutated in a way that + # requires us to handle the mutation outside of the graph, we need to return it + # from the forward graph. The inner_meta data won't consider the component tensors + # as if they need to be returned, because they don't require grad; but really, we + # should handle those tensors the same way we handle the subclass tensor itself; i.e. + # if we'd include the subclass tensor as part of the outputs, then we should also + # include the component tensors. + # + # To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs + # from the outer subclass tensors and propagating + + updated_input_info = [] + inner_idx = 0 + if not fw_metadata.subclass_inp_meta: + # Sometimes we don't have subclass info, e.g. 
synthetic_base codepaths + return inner_metadata.mutated_inp_runtime_indices + assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info) + for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta): + if isinstance(inp_meta, PlainTensorMeta): + assert outer_idx < len(fw_metadata.input_info) + if inner_metadata is not None: + assert inner_idx < len(inner_metadata.input_info) + assert ( + inner_metadata.input_info[inner_idx] + == fw_metadata.input_info[outer_idx] + ) + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + else: + assert inp_meta.original_subclass is not None + for _ in range(inp_meta.arg_count): + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + if inner_metadata is not None: + assert len(inner_metadata.input_info) == len(updated_input_info) + + return [ + i + for i, inp in enumerate(updated_input_info) + if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1255a6de8bf6e8f2d695c12c464be9c58aa171f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py @@ -0,0 +1,771 @@ +# mypy: allow-untyped-defs +""" +Contains various utils for AOTAutograd, including those for handling collections. 
+""" + +import copy +import dataclasses +import logging +import operator +import warnings +from collections.abc import Callable +from contextlib import nullcontext +from functools import wraps +from typing import Any, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch.utils._pytree as pytree +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import getArtifactLogger +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import py_sym_types + +from .descriptors import AOTOutput + + +KNOWN_TYPES = [ + torch.Tensor, + BackwardState, + int, + str, + float, + bool, + type(None), + *py_sym_types, + FakeScriptObject, + torch.ScriptObject, +] + +original_zip = zip + +aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects") +annotation_log = getArtifactLogger(__name__, "annotation") + + +def strict_zip(*iterables, strict=True, **kwargs): + if not strict: + return original_zip(*iterables, **kwargs) + + length = len(iterables[0]) + for iterable in iterables[1:]: + if len(iterable) != length: + raise ValueError( + "The iterables have different lengths and strict mode is enabled." + ) + + return original_zip(*iterables, **kwargs) + + +def _get_symint_hints(exprs): + """ + Get the hints of a list/tuple of int/SymInt. 
+ """ + if isinstance(exprs, (list, tuple)): + return type(exprs)(_get_symint_hints(e) for e in exprs) + elif isinstance(exprs, torch.SymInt): + return exprs.node.shape_env.size_hint(exprs.node.expr) + else: + return exprs + + +def partial_flatten_asdict(obj: Any) -> Any: + if dataclasses.is_dataclass(obj): + return { + field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) + } + elif isinstance(obj, (list, tuple)): + return obj.__class__([partial_flatten_asdict(item) for item in obj]) + elif isinstance(obj, dict): + return {k: partial_flatten_asdict(v) for k, v in obj.items()} + else: + return obj + + +def normalize_as_list(x): + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +def _get_autocast_states(): + return [ + torch.is_autocast_enabled("cuda"), + torch.is_autocast_enabled("cpu"), + torch.get_autocast_dtype("cuda"), + torch.get_autocast_dtype("cpu"), + torch.is_autocast_cache_enabled(), + ] + + +def make_boxed_func(f): + @simple_wraps(f) + def g(args): + return f(*args) + + g._boxed_call = True # type: ignore[attr-defined] + return g + + +def make_boxed_compiler(compiler): + @wraps(compiler) + def f(fx_g, inps): + out_f = compiler(fx_g, inps) + fx_g = make_boxed_func(out_f) + return fx_g + + return f + + +def call_func_at_runtime_with_args( + f, args: Union[tuple[Any], list[Any]], steal_args=False, disable_amp=False +): + if not steal_args: + args = list(args) + assert isinstance(args, list) + + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(): + if getattr(f, "_boxed_call", False): + out = normalize_as_list(f(args)) + else: + # TODO: Please remove soon + # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 + warnings.warn( + "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. " + "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. 
" + "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.", + stacklevel=2, + ) + out = normalize_as_list(f(*args)) + return out + + +# Inspired by autodidax (thanks!) +class PytreeThunk: + spec: Optional[pytree.TreeSpec] = None + # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. + is_simple: Optional[bool] = ( + None # if the output spec is a tuple/list, we won't bother unflattening it. + ) + is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec + + def set(self, spec: pytree.TreeSpec) -> None: + assert self.spec is None or self.spec == spec + assert spec is not None + self.spec: pytree.TreeSpec = spec + if self.spec.type in {tuple, list} and all( + child.is_leaf() for child in spec.children() + ): + self.is_simple = True + if self.spec.is_leaf(): + self.is_really_simple = True + + def unflatten(self, x: list[Any]) -> Any: + if self.is_really_simple: + return x[0] + if self.is_simple: + return x + assert self.spec is not None + return pytree.tree_unflatten(x, self.spec) + + +# Creates a function that returns flattened inputs and outputs +# Also returns the output tree spec, which is needed to recover the "unflattened" +# output tree structure later. +def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThunk]: + if kwargs is None: + kwargs = {} + # Save the args_spec for flat_tensor_args to unflatten while tracing + _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + out_spec = PytreeThunk() + + def flat_fn(*flat_args): + # The input are flattened tensor args. Prepare the args in the + # order that original function expects. Add static args as well. + # They will appear as tensor constants in the traced graph. 
+ nonlocal out_spec + args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec) + tree_out = fn(*args, **kwargs) + flat_out, spec = pytree.tree_flatten(tree_out) + for i in flat_out: + is_known_type = False + for j in KNOWN_TYPES: + if isinstance(i, j): + is_known_type = True + break + if not is_known_type: + raise RuntimeError( + f"Found {type(i)} in output, which is not a known type. " + "If this type holds tensors, you need to register a pytree for it. " + "See https://github.com/pytorch/functorch/issues/475 for a brief " + "explanation why. If you don't need to register a pytree, please " + "leave a comment explaining your use case and we'll make this more " + "ergonomic to deal with" + ) + out_spec.set(spec) + return flat_out + + # Can't use functools.wraps here because the wrapper has different + # calling convention + if hasattr(fn, "_orig_mod"): + flat_fn._orig_mod = fn._orig_mod # type: ignore[attr-defined] + + return flat_fn, out_spec + + +# This function takes in a tensor t, and returns one of t, t.view(), or t.clone(). +# When tracing the joint forward + backward, for any inputs in the graph that are mutated, +# we need to clone them first (and similarly for metadata-only mutations, we need to view them first). +# The idea is that when we trace the backward, we need to pass in the *original* primals +# to autograd.grad(), before they were mutated. +# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them. +# This means that "idx" here represents the index of the (potentially) synthetic base. +# What we need to do is: +# (1) map the current (post-synthetic-base calling convention) input argument index +# to int index pre-synthetic-base-calling-convention. +# (2) There could be multiple, if this index corresponds to a synthetic base +# that has multiple input aliases. +# (3) If any of those corresponding inputs get metadata mutations, then we clone the base. 
+def maybe_to_fresh_input(idx, t, meta): + if not isinstance(t, torch.Tensor): + return t + if idx in meta.mutated_inp_runtime_indices: + # We only need to bother cloning mutated inputs that participate in autograd. + if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the mutation + return t.clone() + if meta.input_info[idx] and meta.input_info[idx].mutates_metadata: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the metadata mutation + return t.view(t.shape) + return t + + +def is_with_effects(node): + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.with_effects + ): + return True + elif ( + node.op == "call_function" + and node.target is torch.ops.higher_order.invoke_subgraph + ): + # Check if subgraph has effects by looking in the cache + from torch._guards import InvokeSubgraphCache, TracingContext + + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + effects = invoke_subgraph_cache.get_effects(node.args[1]) + return effects is not None + return False + + +def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None): + # Remove the tokens from the inputs/outputs of the graph since inductor does + # not want these extra inputs/outputs, and replace them with + # _make_token() to create a token, and _sink_tokens() to collect the + # tokens. See Note [Side-Effectful Tokens in AOTAutograd] + # Logic: + # 1. In the case of with_effects: + # Before: + # ``` + # def forward(self, token, arg1_1): + # with_effects = torch.ops.higher_order.with_effects(token, ...) 
+ # getitem = with_effects[0] + # getitem_1 = with_effects[0] + # return (getitem, getitem_1) + # ``` + # + # After: + # ``` + # def forward(self, arg1_1): + # _make_token_default = torch.ops.prims._make_token.default() + # with_effects = torch.ops.higher_order.with_effects(_make_token_default, ...) + # getitem = with_effects[0] + # getitem_1 = with_effects[0] + # _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]); + # return (getitem_1,) + # ``` + # + # 2. In the case of an invoke_subgraph node, we will use the + # InvokeSubgraphCache to determine if the subgraph has effects. Then we will + # turn it into a `with_effects` node. This is so that at the toplevel graph, + # the nodes will have the correct with_effects threading. We will apply this + # pass recursively to submodules so the tokens will be removed from the + # subgraph's inputs. + # + # Before: + # ``` + # def forward(self, token, arg1_1): + # repeated_subgraph0 = self.repeated_subgraph0 + # invoke_subgraph = torch.ops.higher_order.invoke_subgraph( + # repeated_subgraph0, 'subgraph_0', token, x, arg1_1) + # getitem = invoke_subgraph[0] + # getitem_1 = invoke_subgraph[1] + # return (getitem, getitem1) + # ``` + # + # After: + # ``` + # def forward(self, arg1_1): + # _make_token_default = torch.ops.prims._make_token.default() + # repeated_subgraph0 = self.repeated_subgraph0 + # with_effects_1 = torch.ops.higher_order.with_effects( + # _make_token_default, torch.ops.higher_order.invoke_subgraph, + # repeated_subgraph0, 'subgraph_0', arg1_1) + # getitem = with_effects_1[0] + # getitem_1 = with_effects_1[1]; with_effects_1 = None + # _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]) + # return (getitem_1,) + # ``` + # + # 3. 
The toplevel module should have the following invariants: + # forward: + # expected_num_erased_inputs == len(fw_metadata.tokens) + # expected_num_erased_outputs == len(fw_metadata.tokens) + # backward: + # expected_num_erased_inputs == fw_metadata.num_backward_tokens + # expected_num_erased_outputs == fw_metadata.num_backward_tokens + num_forward_tokens = len(fw_metadata.tokens) + num_backward_tokens = fw_metadata.num_backward_tokens + + def replace_input_token_with_make_token(module, node): + with module.graph.inserting_before(node): + new_token_node = module.graph.call_function( + torch.ops.prims._make_token.default, () + ) + new_token_node.meta["val"] = torch.tensor([]) + new_token_node.meta["tensor_meta"] = torch.tensor([]) + node.replace_all_uses_with(new_token_node) + module.graph.erase_node(node) + + def get_output_tokens(node: torch.fx.Node) -> set[torch.fx.Node]: + output_tokens = set() + for user in list(node.users.keys()): + # Check if this is a getitem accessing index 0 (the token) + if ( + user.op == "call_function" + and user.target is operator.getitem + and len(user.args) > 1 + and user.args[1] == 0 + ): + # Check if this getitem is used in an output + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.add(user) + return output_tokens + + def _unlift_tokens_from_module_helper( + module: torch.fx.GraphModule, + subgraph_str: str, + expected_num_erased: Optional[int], + ): + input_token_nodes = set() + output_token_nodes = set() + + for node in module.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.with_effects + ): + if node.args[0].op == "placeholder": + input_token_nodes.add(node.args[0]) + replace_input_token_with_make_token(module, node.args[0]) + + tokens_from_with_effects = get_output_tokens(node) + output_token_nodes = output_token_nodes | tokens_from_with_effects + + elif ( + node.op == "call_function" + and node.target is torch.ops.higher_order.invoke_subgraph 
+ ): + subgraph_node, identifier, *operands = node.args + + # Check if subgraph has effects by looking in the cache + from torch._guards import InvokeSubgraphCache, TracingContext + + effects = None + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = ( + tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + effects = invoke_subgraph_cache.get_effects(identifier) + + if effects is not None: + # Wrap invoke_subgraph with with_effects + # Before: invoke_subgraph(subgraph, id, token, *args) -> (token_out, result) + # After: with_effects(token, invoke_subgraph, subgraph, id, *args) -> (token_out, result) + # + # Note: The subgraph itself will be unlifted separately when we iterate + # through named_modules() below. + + num_tokens = len(effects) + assert num_tokens == 1, "Multiple token subgraph NYI" + token_args = operands[:num_tokens] + non_token_args = operands[num_tokens:] + + # Create with_effects wrapper around invoke_subgraph + # with_effects(token, op, *args) where op is invoke_subgraph + # Pass the subgraph and non-token args to invoke_subgraph + with module.graph.inserting_before(node): + new_node = module.graph.call_function( + torch.ops.higher_order.with_effects, + ( + token_args[0], # pyrefly: ignore[bad-argument-type] + torch.ops.higher_order.invoke_subgraph, + subgraph_node, + identifier, + *tuple(non_token_args), + ), + ) + node.replace_all_uses_with(new_node) + new_node.meta = node.meta + module.graph.erase_node(node) + + for token in token_args: + if token.op == "placeholder": + input_token_nodes.add(token) + replace_input_token_with_make_token(module, token) + + # Get output tokens from the new with_effects node + tokens_from_invoke_subgraph = get_output_tokens(new_node) + output_token_nodes = ( + output_token_nodes | tokens_from_invoke_subgraph + ) + + output_node = 
next(reversed(module.graph.find_nodes(op="output"))) + assert output_node is not None + with module.graph.inserting_before(output_node): + module.graph.call_function( + torch.ops.prims._sink_tokens.default, + (list(output_token_nodes),), + ) + new_out_args = tuple( + [out for out in output_node.args[0] if out not in output_token_nodes] + ) + output_node.args = (new_out_args,) + + if expected_num_erased: + assert len(input_token_nodes) == expected_num_erased, ( + f"{subgraph_str} num_erased_inputs:{len(input_token_nodes)} " + f"{input_token_nodes} != expected {expected_num_erased} \n" + f"{fw_module.print_readable(print_output=False)}" + ) + assert len(output_token_nodes) == expected_num_erased, ( + f"{subgraph_str} num_erased_outs:{len(output_token_nodes)} " + f"{output_token_nodes} != expected {expected_num_erased} \n" + f"{fw_module.print_readable(print_output=False)}" + ) + + module.recompile() + + def unlift_tokens_from_module(module, subgraph_str, expected_num_erased): + for name, m in module.named_modules(): + if isinstance(m, torch.fx.GraphModule): + if name == "": + _unlift_tokens_from_module_helper( + m, subgraph_str, expected_num_erased + ) + else: + # Subgraph -- we may or may not have effects applied + _unlift_tokens_from_module_helper(m, f"{subgraph_str}_{name}", None) + + if num_forward_tokens > 0: + if aot_config.enable_log: + from torch._dynamo.utils import lazy_format_graph_code + + aot_graphs_effects_log.debug( + "%s", + lazy_format_graph_code( + "Forward graph before unlifting tokens", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + unlift_tokens_from_module( + fw_module, + "forward", + num_forward_tokens, + ) + + if bw_module is not None and num_backward_tokens > 0: + if aot_config.enable_log: + from torch._dynamo.utils import lazy_format_graph_code + + aot_graphs_effects_log.debug( + "%s", + lazy_format_graph_code( + "Backward graph before unlifting tokens", + bw_module, + 
aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + unlift_tokens_from_module(bw_module, "backward", num_backward_tokens) + + # This is sad, but we need to update the metadata to get rid of + # the tokens. + fw_metadata.tokens = {} + fw_metadata.num_backward_tokens = 0 + + +def root_module_when_exporting_non_strict(flat_fn): + # When exporting in non-strict mode, we wrap the root module in a specific pattern. + # See `_aot_export_non_strict` in torch.export._trace.py. + # We look for that wrapping pattern here. + if hasattr(flat_fn, "_orig_mod") and hasattr(flat_fn._orig_mod, "_export_root"): + return flat_fn._orig_mod._export_root + else: + return None + + +def _is_forward_node_with_seq_nr(node: torch.fx.Node) -> bool: + # For now, assume that if nn_module_stack_metadata is populated, this + # node is from the forward. Ignore nodes without `seq_nr`. + # TODO(future): there is likely a less brittle way to do this by walking + # the descendants of graph inputs corresponding to fwd inputs, didn't + # seem obvious at first glance on how to partition graph inputs into + # fwd vs bwd without relying on string names. + return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta + + +def _is_backward_node_with_seq_nr(node: torch.fx.Node) -> bool: + # For now, assume that if nn_module_stack_metadata is not populated, + # this node is from the backward. Ignore nodes without `seq_nr`. + # TODO(future): there is likely a less brittle way to do this, same + # as with the forward. 
+ return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta + + +def _collect_fwd_nodes_from_subgraph( + fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node] +) -> None: + """Collect forward nodes from a single subgraph into the global mapping.""" + for node in fx_g.graph.nodes: + if not _is_forward_node_with_seq_nr(node): + continue + seq_nr = node.meta["seq_nr"] + if seq_nr in fwd_seq_nr_to_node: + # If we already saw an op with the current `seq_nr`, that means + # that the current op did not create an autograd node, and there + # is no corresponding backward node, so we skip. + continue + fwd_seq_nr_to_node[seq_nr] = node + + +def _copy_metadata_to_bw_nodes_in_subgraph( + fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node] +) -> None: + """Copy metadata from forward nodes to backward nodes in a single subgraph.""" + for node in fx_g.graph.nodes: + annotation_log.debug("node: %s", node.name) + seq_nr = node.meta.get("seq_nr") + annotation_log.debug("seq_nr: %s", seq_nr) + + if not _is_backward_node_with_seq_nr(node): + continue + + # We exclude gradient accumulation nodes from copying tags + if node.meta.get("is_gradient_acc", False): + annotation_log.debug("is_gradient_acc") + continue + + # fwd_node should always exist, but handle non-existence just in case + fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"]) + if fwd_node is not None: + node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack") + node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") + # TODO: better to change to a specific field of custom? + custom = fwd_node.meta.get("custom") + if custom is not None: + node.meta["custom"] = copy.deepcopy(custom) + + +def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None: + """ + Input: `fx_g` which contains the joint fwd+bwd FX graph created by + aot_autograd. 
+ + This function walks the graph and copies over metadata from forward nodes + to backward nodes, using the `seq_nr` field as a one-to-many mapping + from forward node to backward node. This metadata is useful for performance + profiling and debugging. + + This function supports matching forward and backward nodes across different + subgraphs (e.g., in recursive submodules from HOPs), enabling backward nodes + in any submodule to match forward nodes in any submodule. + """ + + # Build a global mapping of seq_nr to forward nodes across all subgraphs + fwd_seq_nr_to_node: dict[str, torch.fx.Node] = {} + + # First pass: collect all forward nodes from all subgraphs + for submod in fx_g.modules(): + if isinstance(submod, torch.fx.GraphModule): + _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node) + + if annotation_log.isEnabledFor(logging.DEBUG): + for k, v in fwd_seq_nr_to_node.items(): + annotation_log.debug("forward:: key: %s, value: %s", k, v) + + # Second pass: copy metadata to backward nodes in all subgraphs + # using the global forward mapping + for submod in fx_g.modules(): + if isinstance(submod, torch.fx.GraphModule): + _copy_metadata_to_bw_nodes_in_subgraph(submod, fwd_seq_nr_to_node) + + +def register_buffer_assignment_hook(mod, assigned_buffers): + """ + Register a hook that intercepts buffer assignments. + This is used to detect when a buffer is assigned to, and then we can + map that buffer to the corresponding proxy node in the graph. + """ + + def _map_assigned_buffer_to_proxy(_mod, name, buffer): + # We intercept buffer assignments on the root module through this hook. + if _mod._buffers is mod._buffers: + # either buffer is a functional tensor, which wraps a fake tensor + if isinstance(buffer, FunctionalTensor): + buffer = buffer.from_functional() + # or buffer is a fake tensor + assert isinstance(buffer, FakeTensor) + # The fake tensor in turn is associated with a proxy node. 
+ proxy_mode = torch.fx.experimental.proxy_tensor.get_proxy_mode() + assert proxy_mode is not None + proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot( + buffer, proxy_mode.tracer + ).proxy.node + # We map the assigned buffer to this proxy node. + assigned_buffers[name] = proxy.name + return buffer + + return torch.nn.modules.module.register_module_buffer_registration_hook( + _map_assigned_buffer_to_proxy + ) + + +def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool: + """ + Checks if the module contains any metadata mutation ops. + """ + for node in module.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "tags") + and torch.Tag.inplace_view in node.target.tags + ): + return True + return False + + +def get_cuda_generator_meta_val(device_idx: int): + """ + Get a generator value to use as a meta val + + newly cloned generator will not contain tensors. it is only Generators that are + registered to a CUDAGraph that contain tensors. since this does not contain Tensor + it is fine to use in the meta. 
+ """ + return torch.cuda.default_generators[device_idx].clone_state() + + +def top_saved_tensors_hooks(): + return torch._C._autograd._top_saved_tensors_default_hooks(True) + + +def saved_tensors_hooks_are_inlineable(hooks) -> bool: + if not hooks: + return False + pack, unpack = hooks + return isinstance(pack, torch.fx.GraphModule) and isinstance( + unpack, torch.fx.GraphModule + ) + + +_P = ParamSpec("_P") +_T = TypeVar("_T") +_S = TypeVar("_S") + + +def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]: + @wraps(f) + @simple_wraps(f) + def inner(*args, **kwargs): + # pyrefly: ignore [invalid-param-spec] + return f(*args, **kwargs)[0] + + # pyrefly: ignore [bad-return] + return inner + + +_P2 = ParamSpec("_P2") +_R = TypeVar("_R") +_R2 = TypeVar("_R2") + + +def simple_wraps( + f: Callable[_P, _R], +) -> Callable[[Callable[_P2, _R2]], Callable[_P2, _R2]]: + # NB: omit ('__module__', '__name__', '__qualname__') for ease of + # debugging + return wraps(f, assigned=("__doc__", "__annotations__", "__type_params__")) + + +def call_and_expect_output_descs(fn, args): + outs_pair = fn(*args) + assert isinstance(outs_pair, tuple) and len(outs_pair) == 2, (fn, outs_pair) + outs, outs_descs = outs_pair + # The Tensor tests protects against the test when there are no outputs + out_vals, out_spec = pytree.tree_flatten(outs) + out_desc_vals, out_desc_spec = pytree.tree_flatten(outs_descs) + assert out_spec == out_desc_spec, ( + fn_wrappers(fn), + outs, + outs_descs, + out_spec, + out_desc_spec, + ) + assert not any(isinstance(x, AOTOutput) for x in out_vals), ( + fn_wrappers(fn), + outs, + outs_descs, + out_vals, + ) + assert all( + isinstance(d, AOTOutput) + for (x, d) in zip(out_vals, out_desc_vals) + if isinstance(x, (torch.Tensor, torch.SymInt)) or type(x) is int + ), (fn_wrappers(fn), outs, outs_descs, out_vals, out_desc_vals) + return outs_pair + + +def fn_wrappers(fn): + fns = [fn] + f = fn + while hasattr(f, "__wrapped__"): + f = f.__wrapped__ + 
fns.append(f) + return fns diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aeef87e5a418189c712902f912c1c7791945d42 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69e4138687bac6163a5b7465358b5e34a756c092 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a883c71a7c13b47f00fefa8d3434a29ae78c15ac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ced206faf3a0195655d4dddbcac47e5e22587cbc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77f09348087ad9ce4f2baa5c949a2d358e9b5ab1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d884a47c01a069f2489c458cccd8bc40d611dd3d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1146f7d255641a3d6f874dffd68a66888cbe6f31 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdaf592744a0f86238234db714cddbb0cf23e8b0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6515a4d0f45e5fdfda9c688f9b1225b752cc749f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a90c65e4f207cd64e8d6daf31b7de5d9a22dccd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c12894a049090629fb975b2a890c76138cf2074 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be301f39b62112e49324752f50fe2be77a011b54 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..046ab54db16cfc4f0164fcf3f473d2f45e7a8335 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22c6ce153e7f1f16ddd10f189e6598c67cf74c76 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/local_map.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/local_map.cpython-312.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..1dd02fe5d590debb291459d2fcd4476c061b38b4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/local_map.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d76f10668314f458d5d6790870817a014624f252 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1927d2154f8e49675a377a45aa1a18f60ae0bb9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/partitioner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/partitioner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..270a0096d5e951179b0337ed6ffe3c040ff1519d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/partitioner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/print.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/print.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f916d4aa3712b6e0fab116a30db2729461500c81 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/print.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90e2f63bf42ca64525d171c5e6682825842a9247 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd73b7a422efd05ce53223e75c0c174376353e5b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9d082d386b33132d44d9383ffaff9b49b7f638b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51cfd5af9cad9317375d9fecc49f509adc3f1af7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fa666dfb10b278c16d689291c69811769c380ce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc651dde6eee0d5e9ee3178d207e52a40021b93e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..716a5da16601d6bd2cf592269d060727a3807eee Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b33bdc5f095be940c3b575ea3dd9669736b5b3c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a03ef2e3858ab923359ce641984ca332dc1303cc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da7030243a2866458006aa5f2e6085e00b123a1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b8dcc415f5072d32b15c1fb09315e139f9a36ae9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d50b0c67ce830cc049eb092e2a32336160121d47 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3243266ccd548ce5677763ff4f0adacb587bd6be Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/await_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/await_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc8f6f13e092bce125e57f66ae7f24c9930da1cc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/await_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/bounds.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/bounds.cpython-312.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..b7d34a26e9de154b8a1cdb1dbce54ed3792f155b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/bounds.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cab20ecd891b7e1b5f288f3deb080980e449d164 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8991a6c878ea15456bc1bb1051385d47b02610e6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05c72f630594bcaf5211d0a9de766e456b57dead Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms.cpython-312.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..34c2ec281bf91b15532d89c86a6cb409ae8ef888 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms_debug.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms_debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7b1bc85edf252cb8bc903761f184ff50d9cd78f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms_debug.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8df3e5310e394f62af542dfa002315182e41306 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1530690371011792222f91ece91ae9564bc8e334 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config_comms.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config_comms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d2d8cc4eb472a24c88a435162f2dfdf329adec Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config_comms.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..638d70feba3a28a20ea8fab939d0e5ed508e4455 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4dcc5bdfc9ac1aaa9f469137f5a5035785d76c0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87bb5a2927bd335edba260165426881bf4efa780 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..965bdfdf4fd5727cea490477b60c2120bbfc15b9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/debug.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36e58d3e040b22b333f76b6990129e4914638780 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/debug.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/decomposition.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/decomposition.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a974de80c68d7c9e11c84ab23ccac2c301ccbb60 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/decomposition.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/dependencies.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/dependencies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..686198906480c48b1a9d58460c1b164a40309834 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/dependencies.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/distributed_autotune.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/distributed_autotune.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6575f4f8184db0ec9f12ec4ee05c9e423a0952f3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/distributed_autotune.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9497a82d115f1837929b424aea5447c6f9dc510d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/exc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/exc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5f468ca65103f0c40f1531f8abb946dabd032a2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/exc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5353b153af621821c8ba4f0ecd256afe4b517b26 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..199bcc72e51dd36d2d356d632ac02bb1fa0a008a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e52371887147fc5b271f63188c677f67df071fdf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/hooks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60df9590b3086969a9c59e855cfc269c00841956 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/hooks.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/invert_expr_analysis.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/invert_expr_analysis.cpython-312.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..29c913383de957b59f1e35bfb4bf4fb15bfbb301 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/invert_expr_analysis.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5741600e07bc47a7985eb3bfecd91b27eade72d9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_inputs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a7586441b4c52794db9388584d9c50d45f9d9a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_inputs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_template_choice.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_template_choice.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab3b15eb56c10c03616b5dc2f606d913b16e175a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_template_choice.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/loop_body.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/loop_body.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2006783308b49b027f14e325fb30417b74ddf58 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/loop_body.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/memory.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61286a8469bbb97d85f9d598a17ab59a750467a1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/memory.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/metrics.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b2694a0606421979c75d5c4406ba432e0e683d0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/metrics.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b227af0b82d79edafcd0530a01d8dc50246f687a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cfc130f70683bf48f39158e06f640c58c23ffcc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1844d08efe3ecdebb007cd114b6bba76a24846f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bed971c84a47c0843e39da1d658561b1ca147e45 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d80940afaa8358423feb691ab15b54f09b63f40 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/output_code.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/output_code.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8acce90478e10f04fc899521ae37b39b0fec123c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/output_code.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65cb73888994c6f09cec3e2c9921ab878e7a63fc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/rocm_multiarch_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/rocm_multiarch_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb90781f5754c256ca6283f39b7651e080132932 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/rocm_multiarch_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/shape_propagation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/shape_propagation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5374e7c1597b107d31808eedc8360fce23af585 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/shape_propagation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_case.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_case.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a982a0d7971eab5910269b9fd0fd1f01840411 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_case.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae9e72b913a9f88e688307e74cd5f7366c0cfe31 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50aa7a2904e10135e331eaeea19dd09220ea5823 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/virtualized.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/virtualized.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8c44e39cc39efab7161aadce2a7bd9de8b694909 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/virtualized.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce0908cdf19d61ebbb86b6d5105bbce2f0dffb9b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83cbb796adf555672ec631b9f4f3460e9f4d94b8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/device_info.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/device_info.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0e1ec8284ee1f716e8c0ac256e1b4724e1393e3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/device_info.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/profile_analysis.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/profile_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a8392f1c1e5729a04b79e2a09e3b5fe23ea6272 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/profile_analysis.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a94ef9c5de225c6400bcab2323fb7f0af49d027 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3487eee5e973e3c499b4415d8326646d41564a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7544db4d09ea02093e85faff51b59cc38aad41f0 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3852d526b59de3ae7204385783bdaf274a68aab Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37e7c4e76711170c46b4f4db85fe3cffc1e5dd15 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebf134c83d7c597dae05f572f6f6f7f702c9f6e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py @@ -0,0 +1,296 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 52.6245059967041: + if context.get_value('n') <= 
34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 312.0: + return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)] + else: + if context.get_value('k') <= 40.0: + return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)] + else: + return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)] + else: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)] + else: + if context.get_value('k') <= 68.0: + return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)] + else: + return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)] + else: + if context.get_value('k') <= 35.0: + if context.get_value('k') <= 18.0: + if 
context.get_value('m*n') <= 19505152.0: + return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)] + else: + return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)] + else: + if context.get_value('n') <= 68.0: + return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)] + else: + return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)] + else: + if context.get_value('m*n') <= 309760.0: + return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)] + else: + if context.get_value('n') <= 72.0: + return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), 
(0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)] + else: + return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 815360.0: + if context.get_value('k') <= 1184.0: + return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)] + else: + return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)] + else: + if context.get_value('arith_intensity') <= 187.23922729492188: + if context.get_value('mat1_stride_0') <= 198.0: + return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)] + else: + return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)] + else: + return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)] + else: + return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)] diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py new file mode 100644 index 0000000000000000000000000000000000000000..6201acc4213aa153cc73971946d3a241d2063793 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py @@ -0,0 +1,321 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 29.89772129058838: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 432.0: + if context.get_value('arith_intensity') <= 7.8700292110443115: + return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)] + else: + return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)] + else: + if context.get_value('k') <= 40.0: + return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)] + else: + return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), 
(0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)] + else: + if context.get_value('mat1_stride_0') <= 40.0: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)] + else: + return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)] + else: + if context.get_value('mat1_stride_0') <= 68.0: + return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)] + else: + return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)] + else: + if context.get_value('k') <= 18.0: + if context.get_value('m*k') <= 528.0: + return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)] + else: + if context.get_value('n') <= 80.0: + return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)] + else: + return [(0.110, 164), (0.108, 163), (0.106, 168), 
(0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)] + else: + if context.get_value('k') <= 36.0: + if context.get_value('n') <= 68.0: + return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)] + else: + return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)] + else: + if context.get_value('mat2_stride_0') <= 384.0: + return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)] + else: + return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)] + else: + if context.get_value('arith_intensity') <= 56.995582580566406: + if context.get_value('n') <= 68.0: + if context.get_value('k*n') <= 4448.0: + if context.get_value('m*n') <= 29626368.0: + return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), 
(0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)] + else: + return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)] + else: + if context.get_value('k') <= 348.0: + return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)] + else: + return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)] + else: + if context.get_value('m') <= 3264.0: + return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)] + else: + if context.get_value('k') <= 62.5: + return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)] + else: + return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)] + else: + if context.get_value('m*n') <= 1097728.0: + return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)] + else: + if context.get_value('m*n') <= 
3244032.0: + return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)] + else: + if context.get_value('n') <= 136.0: + return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)] + else: + return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba7cbaf90275d1bb2cb50e8fd27fbd331173bbb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py @@ -0,0 +1,150 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if str(context.get_value('1LEQmLEQ16')) != 'True': + if context.get_value('m') <= 32.5: + if context.get_value('n') <= 6976.0: + if context.get_value('n') <= 3520.0: + if context.get_value('m*n') <= 37632.0: + return None + else: + return [(1.000, 13)] + else: + if context.get_value('m*k') <= 452352.0: + return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)] + else: + return [(0.778, 8), (0.222, 13)] + else: + if context.get_value('k*n') <= 102776832.0: + if context.get_value('n') <= 14656.0: + return [(1.000, 11)] + else: + return [(0.889, 11), (0.111, 13)] + else: + return [(1.000, 11)] + else: + if context.get_value('m*n') <= 446464.0: + if context.get_value('m*n') <= 223424.0: + if context.get_value('mat1_stride_0') <= 3968.0: + return None + else: + return None + else: + if context.get_value('m*n') <= 
346112.0: + return [(0.960, 16), (0.040, 7)] + else: + return [(0.750, 16), (0.136, 14), (0.114, 7)] + else: + if str(context.get_value('33LEQmLEQ64')) != 'True': + if context.get_value('n') <= 6976.0: + return [(1.000, 14)] + else: + return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)] + else: + if context.get_value('n') <= 13888.0: + return [(0.710, 14), (0.275, 21), (0.014, 12)] + else: + return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)] + else: + if context.get_value('n') <= 3520.0: + if context.get_value('arith_intensity') <= 3.994754433631897: + if str(context.get_value('mat2_dtype')) != 'torch.uint8': + if context.get_value('m*k') <= 18944.0: + return [(0.577, 5), (0.423, 6)] + else: + return [(0.988, 5), (0.012, 6)] + else: + if context.get_value('arith_intensity') <= 2.9899919033050537: + return None + else: + return None + else: + if context.get_value('arith_intensity') <= 7.956453561782837: + if context.get_value('k*n') <= 9244032.0: + return [(0.822, 5), (0.178, 6)] + else: + return [(0.977, 5), (0.023, 0)] + else: + if context.get_value('m*k') <= 978944.0: + return [(1.000, 5)] + else: + return [(0.971, 5), (0.029, 0)] + else: + if context.get_value('n') <= 13632.0: + if context.get_value('n') <= 6976.0: + return [(1.000, 6)] + else: + if context.get_value('k') <= 3968.0: + return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)] + else: + return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)] + else: + if context.get_value('k*n') <= 39518208.0: + return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)] + else: + if context.get_value('n') <= 20800.0: + return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)] + else: + return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)] diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe46cf75d8c63fab36eef728edf34788d6e3b22 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -0,0 +1,149 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 15.988086223602295: + if context.get_value('n') <= 25280.0: + if context.get_value('n') <= 1344.0: + if context.get_value('mat1_stride_0') <= 7808.0: + return [(0.581, 7), (0.419, 6)] + else: + if context.get_value('m*n') <= 7680.0: + return [(0.875, 0), (0.125, 6)] + else: + return [(0.833, 0), (0.167, 7)] + 
else: + if context.get_value('n') <= 8512.0: + if str(context.get_value('mat2_dtype')) != 'torch.int8': + return [(0.763, 6), (0.237, 7)] + else: + return [(0.725, 7), (0.275, 6)] + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)] + else: + return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)] + else: + if context.get_value('n') <= 42254.0: + if context.get_value('n') <= 33856.0: + if context.get_value('k*n') <= 68157440.0: + return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)] + else: + return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)] + else: + return [(0.659, 5), (0.341, 6)] + else: + if context.get_value('k*n') <= 326052992.0: + if context.get_value('n') <= 55232.0: + return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)] + else: + return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)] + else: + if context.get_value('n') <= 57024.0: + return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)] + else: + return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)] + else: + if context.get_value('m*n') <= 543936.0: + if str(context.get_value('17LEQmLEQ32')) != 'True': + if context.get_value('m*n') <= 262272.0: + if context.get_value('n') <= 1592.5: + return [(0.860, 0), (0.140, 9)] + else: + return None + else: + if context.get_value('m*k') <= 1294336.0: + return [(0.833, 17), (0.150, 18), (0.017, 15)] + else: + return [(0.917, 17), (0.083, 8)] + else: + if context.get_value('n') <= 12416.0: + if context.get_value('m*n') <= 43008.0: + return None + else: + return [(0.853, 14), (0.147, 9)] + else: + return [(0.625, 12), (0.375, 14)] + else: + if context.get_value('m') <= 32.5: + if context.get_value('mat2_stride_1') <= 6656.0: + if context.get_value('n') <= 69184.0: + return [(0.611, 12), (0.361, 14), (0.028, 13)] + else: + return [(1.000, 12)] + else: + if context.get_value('mat2_stride_1') <= 20864.0: + return [(1.000, 
12)] + else: + return [(0.958, 12), (0.042, 9)] + else: + if context.get_value('m*n') <= 1085440.0: + if context.get_value('n') <= 9152.0: + return [(1.000, 18)] + else: + return [(0.780, 18), (0.160, 16), (0.060, 20)] + else: + if context.get_value('m') <= 67.0: + return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)] + else: + return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..b61f8a9dd1e99056864a9dddc663b090f6971214 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -0,0 +1,109 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicRegression, +) + + +class PadMMA100(LearnedHeuristicRegression): + + def __init__(self) -> None: + pass + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + context.context_dict[CHOICE_COL] = choice + return self.predict(context) + + def get_confidence_threshold(self) -> float: + return 1.7025303314066 + + def get_name(self) -> str: + return 'pad_mm' + + def predict(self, context: AHContext) -> float: + if str(context.get_value('choice')) != 'pad': + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 4171264.0: + if context.get_value('m*k') <= 3999308.0: + return 1.8751469764071178 + else: + if str(context.get_value('n_multiple_32')) != 'True': + return 0.9117231355626345 + else: + return 1.1607689608873861 + else: + if str(context.get_value('n_multiple_2')) != 'True': + if str(context.get_value('using_tf32')) != 'True': + return 0.7430382200435992 + else: + return 0.8531269794448678 + else: + if str(context.get_value('k_multiple_2')) != 'True': + return 0.7577181972719917 + else: + return 0.8977349440424219 + else: + if context.get_value('m*n') <= 1299712.0: + return 1.1669723418995592 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + if context.get_value('m*n') <= 55884158.0: + return 1.0262769936909601 + else: + return 1.0022677428470845 + else: + if context.get_value('m') <= 18478.0: + return 1.1127066261894312 + else: + return 1.0337740659894263 + else: + if 
str(context.get_value('mat1_dtype')) != 'torch.float32': + if str(context.get_value('n_multiple_2')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('mat1_stride_0') <= 561.0: + return 1.2900382135142956 + else: + return 1.5761737616057887 + else: + if context.get_value('num_dims_needs_padding') <= 1.5: + return 1.0472263310239422 + else: + return 1.1727673465762514 + else: + if context.get_value('k') <= 28238.5: + if context.get_value('k/(m*n)') <= 0.00026227018679492176: + return 1.6770542505397175 + else: + return 1.3974785435105923 + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return 1.3952699800111992 + else: + return 1.5759286511628336 + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 14119424.0: + return 0.8875772670422478 + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return 1.1467728924377265 + else: + return 1.215842963532998 + else: + if context.get_value('arith_intensity') <= 396.8774871826172: + return 0.89940161869551 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + return 0.9964328169353532 + else: + return 0.9493479238294826 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72f7967ce9ac6e0d6467dcde9fe8e56b3750ca52 Binary files /dev/null 
and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56e25d26b553d93c9b49ba66e536087308d293d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84e6511c6a36d93a7969778cbe08e8034913d773 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c00e334f0fc303d12898b6a2f2ea360fdb4eacac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67f28104f16ddd8cc5be6eb475a831aef26928e1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e9d17fe176c8f1c9de68ee183950790fba7a811 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f62128f97e7ff509677990b876c8c8e2b3c2afbe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__pycache__/__init__.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcab83b8a423d70258c49e3260ce712c42b47cad Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e12a86af8ab0ab8d7d7b2d8bf37ec6dec861e0ff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py @@ -0,0 +1,6 @@ +import torch + + +__version__ = torch.version.cuda + +from .cuda import * # noqa: F403 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8ff8214d81a30e4e749e02a7e51cea520597281 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/__init__.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cuda.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cuda.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b2c3d2ec4a0fbe6891a42a07827b74f939965f1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cuda.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cudart.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cudart.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b389c88a456824ac55225c54805ed5ec076e7f9e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cudart.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..ad41f04fc897e33f4530eb42c76a104def58f413 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py @@ -0,0 +1,24 @@ +# mypy: disable-error-code="no-untyped-def" +# flake8: noqa 
+import torch + + +class CUdeviceptr: + pass + + +class CUstream: + def __init__(self, v): + pass + + +class CUresult: + CUDA_SUCCESS = True + + +class nvrtc: + pass + + +def cuDeviceGetCount(): + return (CUresult.CUDA_SUCCESS, torch.cuda.device_count()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2ee5f1f6163d7b20336d6102ce5d8f97880c87 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py @@ -0,0 +1,17 @@ +# mypy: disable-error-code="no-untyped-def" +import torch.cuda + + +class cudaError_t: + cudaSuccess = True + + +def cudaFree(n): + return (cudaError_t.cudaSuccess,) + + +def cudaGetDeviceProperties(d): + class DummyError: + value = False + + return (DummyError(), torch.cuda.get_device_properties(d)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8aefb6171b682f062cfe57a1876f51b280f120cc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py @@ -0,0 +1,2 @@ +# mypy: disable-error-code="var-annotated" +Dot = None diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59f95ecc79040294a0f2348f65f2badd3c4e1423 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0378d35a9c442559373f035e45de19b2be927cd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py @@ -0,0 +1,3 @@ +# typing: ignore +# flake8: noqa +from .special import * diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d45030b7fdf66caeede6c104c9216bbd78aa87 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/special.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/special.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2906ca0d59c01af72b96a0a6ca755b02c8cbb923 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/special.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py new file mode 100644 index 0000000000000000000000000000000000000000..79af3029aa0b18d0ad55633f8cca8af8b76b520b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py @@ -0,0 +1,2 @@ +# mypy: disable-error-code="var-annotated" +erf = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ae3375c9475619cd6c88ffa4703626f55ad8113 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/_cutedsl_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/_cutedsl_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d19f61acc3ef713b83a1ea34190bc1e78d29c60 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/_cutedsl_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_kernel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6ab497f6544ac8a1b312619be2a9832956e06d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_kernel.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_op_overrides.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72d6d64ec810e83dac3e1bc555c9fd08525c25b0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_op_overrides.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_scheduling.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705605e5e659fb66af7e7f3540aad61a0295a717 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_scheduling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4831cddf7dd6ec7e77919f1ef86250392aef6fb3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65205d5ba0269b367ac671d4c5f9109c0b6f546a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/device_op_overrides.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..580c027a686e2d7b84f23e225cd891038018ca58 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/device_op_overrides.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da1a93b8bc513294714bae6e0b5be1b5bcdf5788 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd5c1f4aa861f9edab4533cdb4ab45a50e9c1a5a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbf58984632bc96b1f1b8d90582d5203a8dfba25 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9cd19a88aecec2fbc52e4adb969abbebd404cdf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b5bd05d59ddc5029e784ad2c465b827da7f3333 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f3bd2cc3add2b17761adf524106e1089358fce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..086af579c6b96d4e9ecf63e3e3d38ef596d09563 
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c71312750c1fd8c680721ff370aabfb106c1fe8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f469ef76c804cc78cb293158e42c1f4a34a7030 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a647361e3cf0f48888f2da81876869321a47518c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62038b2b565d880fe4838387e49f072f3d840dc9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f7401c79bdb9c8356ab190e170ab21f276d6cac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cec6d9b598358f15df36d875977ad6186dd0a3b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6acafc7744d4872089b954745a7edffd03eda86c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd90e2234caab727ff5348e91d6561b6acc5e18f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca0f1e5a4fb2a6aeb1224285d76e78a05a0f499 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +import argparse +import base64 +import functools +import importlib +import logging +import os +import sys +from typing import TypeVar + +from torch._inductor.async_compile import pre_fork_setup +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.subproc_pool import ( + SubprocKind, + SubprocMain, + SubprocPickler, +) +from 
torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path + + +_T = TypeVar("_T") + + +log = logging.getLogger(__name__) + +_set_triton_ptxas_path() + +try: + import triton + + assert triton is not None # preload in parent +except ImportError: + pass + + +def _lookup_and_create_type(base: type[_T], qname: str) -> _T: + """ + Given a base type and qualified name: import & lookup that name, check + that it's of the given type and then instantiate it. + """ + pkg, name = qname.rsplit(".", 1) + mod = importlib.import_module(pkg) + ty = getattr(mod, name) + if not issubclass(ty, base): + raise TypeError(f"Type {ty} is not a subtype of {base}") + return ty() + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument( + "--pickler", type=functools.partial(_lookup_and_create_type, SubprocPickler) + ) + parser.add_argument("--kind", type=SubprocKind) + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + parser.add_argument("--torch-key", type=str) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + read_fd = os.fdopen(args.read_fd, "rb") + write_fd = os.fdopen(args.write_fd, "wb") + + pre_fork_setup() + + torch_key.set(base64.b64decode(args.torch_key.encode("utf-8"))) # type: ignore[attr-defined] + + _async_compile_initializer(args.parent) + + SubprocMain(args.pickler, args.kind, args.workers, read_fd, write_fd).main() + except Exception: + log.exception("Uncaught exception in compile_worker subprocess") + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py new file mode 100644 index 
0000000000000000000000000000000000000000..07c59b8cbb860fd1ed0e1ff1ba6df34979abdf4f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py @@ -0,0 +1,496 @@ +import base64 +import functools +import itertools +import logging +import multiprocessing +import os +import pickle +import struct +import subprocess +import sys +import threading +import traceback +import typing +from collections.abc import Callable +from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from enum import Enum, IntEnum +from typing import Any, IO, Optional, TypeVar +from typing_extensions import Never, ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 +from torch._inductor import config +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.timer import Timer +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.utils import get_ld_library_path, python_subprocess_env +from torch._utils_internal import find_compile_subproc_binary +from torch.monitor import _WaitCounter, _WaitCounterTracker + + +log = logging.getLogger(__name__) + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class MsgHeader(IntEnum): + ERROR = 0 + SHUTDOWN = 1 + QUIESCE = 2 + WAKEUP = 3 + JOB = 4 + + +def _pack_msg(msg_header: MsgHeader, job_id: int, length: int) -> bytes: + return struct.pack("nnn", int(msg_header), job_id, length) + + +def _unpack_msg(data: bytes) -> tuple[MsgHeader, int, int]: + if not data: + return MsgHeader.ERROR, -1, -1 + msg_header, job_id, length = 
struct.unpack("nnn", data) + return MsgHeader(msg_header), job_id, length + + +msg_bytes = len(_pack_msg(MsgHeader.JOB, 0, 0)) + + +def _send_msg( + write_pipe: IO[bytes], msg_header: MsgHeader, job_id: int = -1, data: bytes = b"" +) -> None: + length = len(data) + write_pipe.write(_pack_msg(msg_header, job_id, length)) + if length > 0: + write_pipe.write(data) + write_pipe.flush() + + +def _recv_msg(read_pipe: IO[bytes]) -> tuple[MsgHeader, int, bytes]: + msg_header, job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) + data = read_pipe.read(length) if length > 0 else b"" + return msg_header, job_id, data + + +class _SubprocExceptionInfo: + """ + Carries exception info from subprocesses across the wire. traceback + objects are not pickleable, so we store the trace as a string and + use it for the message in the exception thrown in the main process. + """ + + def __init__(self, details: str) -> None: + self.details = details + + +class SubprocException(Exception): + """ + Thrown when a job in a subprocess raises an Exception. + """ + + def __init__(self, details: str, name: str = "") -> None: + self.details = details + super().__init__( + f"An exception occurred in a subprocess:\n\nName={name}\n{details}" + ) + + def with_name(self, name: str) -> "SubprocException": + return SubprocException(self.details, name) + + +class SubprocPickler: + """ + Allows a caller to provide a custom pickler for passing data with the + subprocess. 
+ """ + + def dumps(self, obj: object) -> bytes: + return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) + + def loads(self, data: bytes) -> object: + return pickle.loads(data) + + +class SubprocKind(Enum): + FORK = "fork" + SPAWN = "spawn" + + +class SubprocPool: + """ + Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in + a subprocess.Popen() to try to avoid issues with forking/spawning + """ + + def __init__( + self, + nprocs: int, + pickler: Optional[SubprocPickler] = None, + kind: SubprocKind = SubprocKind.FORK, + quiesce: bool = False, + ) -> None: + entry = os.path.join(os.path.dirname(__file__), "__main__.py") + self.pickler = pickler or SubprocPickler() + self.kind = kind + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + torch_key_str = base64.b64encode(torch_key()).decode("utf-8") + + cmd = [ + sys.executable, + entry, + ] + if (binary := find_compile_subproc_binary()) is not None: + cmd = [binary] + + args = [ + f"--pickler={self.pickler.__class__.__module__}.{self.pickler.__class__.__name__}", + f"--kind={self.kind.value}", + f"--workers={nprocs}", + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + f"--torch-key={torch_key_str}", + ] + cmd.extend(args) + log_path = None + self.log_file = None + + if config.worker_suppress_logging: + log_path = os.devnull + log.info("Suppressing compile worker output due to config") + else: + log_path = config.torchinductor_worker_logpath + if not log_path: + log_path = config.get_worker_log_path() + + if log_path: + # pyrefly: ignore [bad-assignment] + self.log_file = open(log_path, "w") # noqa:SIM115 + + self.process = subprocess.Popen( + cmd, + env={ + **python_subprocess_env(), + # Safeguard against creating a SubprocPool in the subprocess. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. 
+ "LD_LIBRARY_PATH": get_ld_library_path(), + }, + pass_fds=(subproc_read_fd, subproc_write_fd), + stdout=self.log_file, + stderr=self.log_file, + ) + self.write_lock = threading.Lock() + self.read_thread = threading.Thread( + target=self._read_thread, name="InductorSubproc", daemon=True + ) + + self.futures_lock = threading.Lock() + self.pending_futures: dict[int, Future[Any]] = {} + # The pending waitcounter, is used to indicate the time when we have any specific job running. + self.pending_waitcounters: dict[int, Any] = {} + self.job_id_count = itertools.count() + + # The running waitcounter indicates the time when the SubProcPool object exists. + self.running = True + self.running_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.running" + ).guard() + self.running_waitcounter.__enter__() + + # The quiesce waitcounter indicates when the job is in a quiesced state. + self.quiesce_waitcounter: Optional[_WaitCounterTracker] = None + + # Firstjob is used to capture the time from when the firstjob is queued, to when the first job is done. + self.firstjob = True + self.firstjob_id: Optional[int] = None + self.firstjob_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.first_job" + ).guard() + + if quiesce: + self.timer: Optional[Timer] = Timer( + config.quiesce_async_compile_time, self.quiesce + ) + else: + self.timer = None + + # Start thread last to ensure all member variables are initialized + # before any access. 
+ self.read_thread.start() + + def submit( + self, job_fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_T]: + if args or kwargs: + # pyrefly: ignore [bad-assignment] + job_fn = functools.partial(job_fn, *args, **kwargs) + job_data = self.pickler.dumps(job_fn) + future: Future[_T] + with self.futures_lock: + job_id = next(self.job_id_count) + self.pending_futures[job_id] = future = Future() + self.pending_waitcounters[job_id] = _WaitCounter( + "pytorch.wait_counter.subproc_pool.job" + ).guard() + self.pending_waitcounters[job_id].__enter__() + if self.quiesce_waitcounter: + self.firstjob = True + self.quiesce_waitcounter.__exit__() + self.quiesce_waitcounter = None + # This can be entered from either quiesce wakeup, or from startup. + if self.firstjob: + self.firstjob_id = job_id + self.firstjob_waitcounter.__enter__() + self.firstjob = False + future.set_running_or_notify_cancel() + self._send(MsgHeader.JOB, job_id, job_data) + return future + + def _send(self, msg_header: MsgHeader, job_id: int = -1, data: bytes = b"") -> None: + with self.write_lock: + if not self.running: + raise RuntimeError("Attempting to use a closed pool") + _send_msg(self.write_pipe, msg_header, job_id, data) + + def _read_thread(self) -> None: + while True: + data = b"" + job_id = -1 + try: + msg_header, job_id, data = _recv_msg(self.read_pipe) + except Exception: + # Something went wrong during the read. There's no way we have a + # valid msg. + log.exception("failure in subproc_pool._recv_msg") + msg_header = MsgHeader.ERROR + + if msg_header != MsgHeader.JOB: + # read_pipe returned None or got exception + if self.running: + log.warning("SubprocPool unclean exit") + self.running = False + self.running_waitcounter.__exit__() + self.read_pipe.close() + # Cancel all the pending futures. + self.shutdown() + return + + try: + result = self.pickler.loads(data) + except Exception as e: + # Something went wrong unpickling. 
We have a job_id so just + # notify that particular future and continue on. + log.exception("unpickle failure in SubprocPool._read_thread") + result = e + + with self.futures_lock: + if not self.running: + return + if self.timer: + self.timer.record_call() + if isinstance(result, _SubprocExceptionInfo): + # An exception occurred in the submitted job + self.pending_futures[job_id].set_exception( + SubprocException(result.details) + ) + elif isinstance(result, Exception): + # An exception occurred in some of our subprocess machinery. + self.pending_futures[job_id].set_exception(result) + else: + self.pending_futures[job_id].set_result(result) + + self.pending_waitcounters[job_id].__exit__() + del self.pending_waitcounters[job_id] + if self.firstjob_id == job_id: + self.firstjob_waitcounter.__exit__() + + del self.pending_futures[job_id] + + def quiesce(self) -> None: + self._send(MsgHeader.QUIESCE) + if self.quiesce_waitcounter is None: + self.quiesce_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.quiesced" + ).guard() + self.quiesce_waitcounter.__enter__() + + def wakeup(self) -> None: + self._send(MsgHeader.WAKEUP) + + def shutdown(self) -> None: + try: + with self.write_lock: + if not self.running: + return + if self.timer: + self.timer.quit() + self.running = False + self.running_waitcounter.__exit__() + _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) + self.write_pipe.close() + self.process.wait(300) + if self.log_file: + self.log_file.close() + except OSError: + log.warning("Ignored OSError in pool shutdown", exc_info=True) + finally: + with self.futures_lock: + for future in self.pending_futures.values(): + if not future.cancel(): + future.set_exception(RuntimeError("SubprocPool closed")) + self.pending_futures.clear() + + +class SubprocMain: + """Communicates with a SubprocPool in the parent process, called by __main__.py""" + + def __init__( + self, + pickler: SubprocPickler, + kind: SubprocKind, + nprocs: int, + read_pipe: IO[bytes], + 
write_pipe: IO[bytes], + ) -> None: + self.pickler = pickler + self.kind = kind + self.read_pipe = read_pipe + self.write_pipe = write_pipe + self.write_lock = threading.Lock() + self.nprocs = nprocs + self.pool: Optional[ProcessPoolExecutor] = None + self.running = True + + def main(self) -> None: + while True: + msg_header, job_id, data = _recv_msg(self.read_pipe) + if msg_header == MsgHeader.JOB: + self.submit(job_id, data) + elif msg_header == MsgHeader.WAKEUP: + self._start_pool() + elif msg_header == MsgHeader.QUIESCE: + self._quiesce() + else: + return self._shutdown() + + def _quiesce(self) -> None: + if self.pool is not None: + self.pool.shutdown(wait=False) + self.pool = None + + def _shutdown(self) -> None: + with self.write_lock: + self.running = False + try: + _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) + self.write_pipe.close() + except BrokenPipeError: + pass # parent process already shutdown + self.read_pipe.close() + self._quiesce() + + def submit(self, job_id: int, data: bytes) -> None: + while self.running: + try: + self._submit_inner(job_id, data) + return + except BrokenProcessPool: + # If any subprocess in the pool crashes, we get a BrokenProcessPool + # exception and the whole pool becomes unusable. Handle crashes by + # recreating the pool and resubmitting. 
+ self.pool = None + + def _submit_inner(self, job_id: int, data: bytes) -> None: + def callback(fut: Future[Any]) -> None: + if not self.running: + return + try: + result = fut.result() + except Exception as e: + log.exception("Error in subprocess") + result = self.pickler.dumps(e) + assert isinstance(result, bytes) + with self.write_lock: + if self.running: + _send_msg(self.write_pipe, MsgHeader.JOB, job_id, result) + return + + self._start_pool() + assert self.pool is not None + + future = self.pool.submit( + functools.partial(SubprocMain.do_job, self.pickler, data) + ) + future.add_done_callback(callback) + + def _start_pool(self) -> None: + if self.pool is not None: + return + + self.pool = TrackedProcessPoolExecutor( + self.nprocs, + mp_context=multiprocessing.get_context(self.kind.value), + initializer=functools.partial(_async_compile_initializer, os.getpid()), + ) + multiprocessing.util.Finalize( + None, self.pool.shutdown, exitpriority=sys.maxsize + ) + _warm_process_pool(self.pool, self.nprocs) + + @staticmethod + def do_job(pickler: SubprocPickler, data: bytes) -> bytes: + # do the pickle/unpickle in the sub-subproc + job = typing.cast(Callable[[], object], pickler.loads(data)) + + try: + result = job() + except Exception: + result = _SubprocExceptionInfo(traceback.format_exc()) + return pickler.dumps(result) + + +AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] + + +def _warm_process_pool(pool: ProcessPoolExecutor, n: int) -> None: + # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the + # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. + + # Examples: + # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup + # tf_efficientnet_b0 benchmark: 50ms! 
in the middle of the program , 3ms at startup + + # So we want to start the workers early when it is still cheap, and also to allow the workers to get + # ready before we have work for them. + + # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. + # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. + + # We force them to start here with some YOLOing of the internal methods. + + if hasattr(pool, "_start_queue_management_thread"): + pool._start_queue_management_thread() + else: + for _ in range(n): + pool._adjust_process_count() + if hasattr(pool, "_start_executor_manager_thread"): + pool._start_executor_manager_thread() + + +class TestException(RuntimeError): + pass + + +def raise_testexc() -> Never: + raise TestException diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/timer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c495403b3a55ef8858bd6661607d7bcf25674e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/timer.py @@ -0,0 +1,55 @@ +from collections.abc import Callable +from threading import Lock, Thread +from time import monotonic, sleep +from typing import Optional, Union + + +class Timer: + """ + This measures how long we have gone since last receiving an event and if it is greater than a set interval, calls a function. + """ + + def __init__( + self, + duration: Union[int, float], # Duration in seconds + call: Callable[[], None], # Function to call when we expire + ) -> None: + # We don't start the background thread until we actually get an event. 
+ self.background_thread: Optional[Thread] = None + self.last_called: Optional[float] = None + self.duration = duration + self.sleep_time = duration / 2 + self.call = call + self.exit = False + + self.lock = Lock() + + def record_call(self) -> None: + with self.lock: + if self.background_thread is None: + self.background_thread = Thread( + target=self.check, daemon=True, name="subproc_worker_timer" + ) + self.background_thread.start() + self.last_called = monotonic() + + def quit(self) -> None: + with self.lock: + self.exit = True + + def check(self) -> None: + while True: + # We have to be sensitive on checking here, to avoid too much impact on cpu + sleep(self.sleep_time) + with self.lock: + if self.exit: + return + assert self.last_called is not None + if self.last_called + self.duration >= monotonic(): + continue + self.last_called = None + self.background_thread = None + + # Releasing lock in case self.call() takes a very long time or is reentrant + self.call() + return diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..546a5cbc6395a104cede30dd94054cfb12193a1b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py @@ -0,0 +1,113 @@ +import atexit +import concurrent +import dataclasses +import logging +import threading +from collections.abc import Callable +from concurrent.futures import Future, ProcessPoolExecutor +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from time import time +from typing import Any, Optional, TypeVar +from typing_extensions import ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. 
For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +log = logging.getLogger(__name__) + + +@dataclass +class _QueueStats: + # Mapping from id(future) -> start time + pending: dict[int, float] = dataclasses.field(default_factory=dict) + timing: list[float] = dataclasses.field(default_factory=list) + enqueue_count: int = 0 + dequeue_count: int = 0 + max_queue_depth: int = 0 + pool_count: int = 0 + + +# The queue statistics tracked by TrackedProcessPoolExecutor. Always grab +# _queue_stats_lock before touching. +_queue_stats = _QueueStats() +_queue_stats_lock = threading.Lock() + + +class TrackedProcessPoolExecutor(ProcessPoolExecutor): + def __init__( + self, + max_workers: Optional[int] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable[[], object]] = None, + ) -> None: + with _queue_stats_lock: + _queue_stats.pool_count += 1 + super().__init__(max_workers, mp_context, initializer) + + def _record_dequeue(self, f: Future[Any]) -> None: + now = time() + with _queue_stats_lock: + stats = _queue_stats + if (start_time := stats.pending.pop(id(f), None)) is None: + return + stats.dequeue_count += 1 + duration = now - start_time + stats.timing.append(duration) + + def _record_enqueue(self, f: Future[Any]) -> None: + # Monkeypatch the set_running_or_notify_cancel so we can track when the Future moves out of PENDING. 
+ saved_running_or_notify_cancel = f.set_running_or_notify_cancel + + def set_running_or_notify_cancel() -> Any: + self._record_dequeue(f) + return saved_running_or_notify_cancel() + + now = time() + with _queue_stats_lock: + stats = _queue_stats + stats.pending[id(f)] = now + stats.enqueue_count += 1 + stats.max_queue_depth = max(stats.max_queue_depth, len(stats.pending)) + f.set_running_or_notify_cancel = set_running_or_notify_cancel # type: ignore[method-assign] + + if f._state != concurrent.futures._base.PENDING: + self._record_dequeue(f) + + def submit( + self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_R]: + # pyrefly: ignore [bad-argument-type] + f = super().submit(fn, *args, **kwargs) + self._record_enqueue(f) + return f + + +@atexit.register +def _queue_stats_report() -> None: + stats = _queue_stats + if stats.pool_count == 0: + return + + timing = stats.timing + timing.sort() + + log.info("AsyncCompile Metrics:") + log.info(" Pools %s", stats.pool_count) + log.info( + " Items %d enqueued / %d dequeued", stats.enqueue_count, stats.dequeue_count + ) + log.info(" Max Queue Depth: %d", stats.max_queue_depth) + n = len(timing) + if n > 0: + log.info(" Longest queue time: %0.2fs", timing[-1]) + log.info(" P50: %0.2fs", timing[n // 2]) + if n >= 20: + log.info(" P95: %0.2fs", timing[n * 95 // 100]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b5e21630c270ada0f45a1f3ff318620fa2deba --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py @@ -0,0 +1,54 @@ +import os +import signal +from threading import Thread +from time import sleep +from typing import Optional + + +_IN_TOPLEVEL_PROCESS = True + + +def in_toplevel_process() -> bool: + 
global _IN_TOPLEVEL_PROCESS + return _IN_TOPLEVEL_PROCESS + + +# If this process dies abnormally (e.g. segfault) +# it will not shut down the workers. Instead, +# the workers will have their parent reassigned to the +# init process. This launches a separate thread to +# watch for the worker getting reassigned, +# and cleans it up in this case. +# +# This function cannot be an inner function since otherwise mp_context="spawn" would +# not work for ProcessPoolExecutor since inner functions cannot be pickled. +def _async_compile_initializer(orig_ppid: int) -> None: + import torch._C + + def run() -> None: + while True: + sleep(60) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread, _original_parent + _original_parent = orig_ppid + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Install a crash handler to print out the stacktrace for SEGV + torch._C._initCrashHandler() + + # Set a bit to distinguish async_compile subprocesses from the toplevel process. 
+ global _IN_TOPLEVEL_PROCESS + _IN_TOPLEVEL_PROCESS = False + + +_watchdog_thread: Optional[Thread] = None +_original_parent: Optional[int] = None + + +def has_parent_changed() -> bool: + return _original_parent != os.getppid() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a70cf5fb9b0a7b2df1e22ebabf0e67a46f0afa0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37c24545669317969d5f783fbce7938d0df20ac1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b49f42a0dce967ac531d9fd7bdd38ef9d6fa38a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/bucketing.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/bucketing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a5d29cdf43bdc9163ee42c2ac2fa6721a186cb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/bucketing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/control_dependencies.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/control_dependencies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7be49780662cdc4701dabfbf0e862f8c07785daf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/control_dependencies.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..446d81c834a0b367c288ea493c7ea14d6b634fcc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63b9281ea01e18e3399f917412ed07c6c7183ad0 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6665ddb61799235840127c88f369e79c5ba900f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd35b59b7c5e48147ff83c0c8c45cc7c0a22f19 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac370a2768dc19be86c94359911052f1b278253 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fsdp.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42bf6e3f5a819cc13e958fcef5e8cbd351bdd5ce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fsdp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99dabf68cc3da5df88a776739f5c4702c4ba45ec Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/graph_view.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/graph_view.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dfa74547bb140734d2acb5b5c54498ab5859c9d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/graph_view.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..144c55be8d491ce745a8c94790a74b2789e8881a Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aad94e9447f6432974ae8879b93e8f622201b2a8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/memory_estimator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/memory_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7965ff63f756e7324fe7974e6dbd2467c45422 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/memory_estimator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c4dede6fc66838b755c43940239c9e15efde148 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f2eca8c18ec4623713c569434fa83b864c669d3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9e955caa7a83fd7959f6969a774322185b9209c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/node_runtime_estimation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/node_runtime_estimation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ac2d0c1ee51096cb3d7f8d3ff24b1bd19d827d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/node_runtime_estimation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8e65cc609f3892e55c44c3d65fd243763113c50 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_manual_scheduling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_manual_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13471fe81d75b9b3236a1c1dde4ea26ce45a5df1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_manual_scheduling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_preserving_bucketer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_preserving_bucketer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e61b4053966a5b9be82fbe43dbaab8af140652 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_preserving_bucketer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_scheduling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef1c43c1cd399443c842e0e9317877984f585a97 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_scheduling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cb391693f8ba04fd209287345277963ab551830 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f0f8b92c16e5c688c0424a571f9394337241bd2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bca053a5e20a73d30f446a569433900967608f1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23ad75ccec89a62dd8787e3d6b287c012358443 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ded33a2bd973e070efee56466cc1b7e1db6f46f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ce38eacaa49dbd1d7ec3737bc5195b8e0645066 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ebfc54ca9c5795a954bfa9af0f9c3d79796b090 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e4d8ced6b023e3f6626c552e0a8ad1b86bf58a0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe198068a281e4049cd4058b9d6fd9171a55018b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..978f31d3d278fd44becfd38d2788d7cc2f5ae877 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-312.pyc differ diff 
--git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3d539e4402f354d887fe9cc609cc38fe5aea948 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95403e65d73ab3c3fa3d537b2c68a743bc946d93 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..772244e9626ad8ed398ae7c8513ec316d1794c6d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa7b5db428fd5c8a43610bcf9553c9c8b4edae70 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d13cedb8b16f5cae85c0574e2db2e39bc7b31df9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0c843f896619d38b2d635eb8d288dc4cefcaebc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-312.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..56fe0373de1accca8e26908e87b9083e049f4640 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa64987d20faf50051db09f33affd98541df6cf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f62f794dc9afe841bf22bcb3e23db417c88a995 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6173ce0e99ad71a6bc775d600ab45f3ca651ebf1 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab26068ffe7c878ff4f761ff6bd4bf733a8e2d75 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75611de30d77053ef1b6992aaa4d6d2c25ccff4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_24.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_24.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ebdfc1469fc6de1e6aab891995c4591283b4671 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_24.cpython-312.pyc differ diff 
--git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef1c7200c0a6b486cccc2d359e03b8b98aaf91dd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b0ebf52e54d41e954415aa0380395d99127cc9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9369b82be0c8d1c3a405fff28f9ed8fe33c3f64b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5e158bbe3fe59055fc8c176e35df16fe997d359 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59ef9db08b06492aa2ac22caf868d53c82572392 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f90acdd34543040f1a5a40abacbe25484885df Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-312.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..5d69477c48e91b5fd429b2b5c8740972fa3e14e4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b62ed1b84e3a6a05682aa4040421c0f0e69757b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c855faed49f82575b9fa638218f3436eb8dca01 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f174f229e6cf7bede2c234d4113ed2b65cbd7835 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8e6de3ff3cba5f0ebcec729c33061b04319d6a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) 
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = 
CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = 
CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = 
CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..567390838ede7dc4d4181f601f020e8066cb07b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = 
CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = 
CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = 
CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, 
KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, 
permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = 
CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa39474c67dd677008c8e7e9266cc875a153196 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = 
CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) 
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) 
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, 
convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) 
+view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..87302d1bab3694a33eac14e263dca86c9f702c75 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, 
Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) 
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + 
permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, 
clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, 
convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, 
memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..d465c1cb4e22b14bfbbdd35d5ed28a43af891523 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,130 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +neg_default = CallFunction(aten.neg.default, div_Tensor) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = 
CallFunction(aten.neg.default, convert_element_type_default_2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) 
+convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..f102038e82c6d5858b8b334e956d87c7e86a9d22 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,210 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = 
CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) 
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), 
True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = 
CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 
= CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default 
= CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) 
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cbb0df340bab14188ec9d5f04a29035dba8d84 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, 
div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_8, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) 
+div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +eq_Scalar = 
CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_8, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = 
CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, 
KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = 
CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..3a15abb9088ff5ffe8cd9af43df11ccc0d5bc143 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,599 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, 
amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, 
mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, 
Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = 
CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), 
_users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, 
view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) 
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = 
CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 
= CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, 
convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, 
fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, 
expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = 
CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, 
expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 
= CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, 
KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = 
CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, 
bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = 
CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = 
CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..812708907b3414e2c864ed36b98f2199e63ae5d2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,246 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) 
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, 
Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, 
view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) 
+view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), 
dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) 
+expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = 
CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), 
dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = 
CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py new file mode 100644 index 0000000000000000000000000000000000000000..567d898ed204257e23a2002479afc5d26cba623b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -0,0 +1,453 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), 
pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, 
view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = 
CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = 
CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = 
CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = 
CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = 
CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = 
CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, 
convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, 
+ None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, 
Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) 
+full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, 
Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, 
view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) 
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6d316351b8595f75fdfd262a4cb2171a8a6b1e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -0,0 +1,209 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, 
sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, 
scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), 
pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) 
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_half_training = 
MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) 
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..f28da434ef0c85ca3d80095e68c052e8dc19dd2d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, 
KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) 
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = 
CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 
= CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = 
CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py new file mode 100644 index 0000000000000000000000000000000000000000..9185aa3b1e3305cfa28f8080be04350beb17c065 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py @@ -0,0 +1,244 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) 
+expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, 
clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, 
view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, 
view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) 
+expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, 
KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, 
expand_default, scalar_tensor_default, convert_element_type_default_5) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = 
CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, 
memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py new file mode 100644 index 0000000000000000000000000000000000000000..4ebd4a4e14e48439eaa0a8b50e9fcf72145dc1a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py @@ -0,0 +1,391 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) 
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) 
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = 
CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, 
bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_bs1_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, 
expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) 
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = 
CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = 
CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, 
clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = 
CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) 
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, 
view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py new file mode 100644 index 0000000000000000000000000000000000000000..0971c09ad972f2bc07ac6ee9f548255a3760faa2 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py @@ -0,0 +1,415 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) 
+amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) 
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = 
CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = 
CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, 
bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = 
CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = 
CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = 
CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), 
_users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) 
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7, + 
permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) 
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) 
+view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, 
add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py new file mode 100644 index 0000000000000000000000000000000000000000..2be036c2e8ae7922b51690da782c5565656d7998 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py @@ -0,0 +1,407 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) 
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) 
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) 
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, 
neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = 
CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 
= CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, 
memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, 
view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, 
Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, 
Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, 
expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, 
Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) 
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py new file mode 100644 index 0000000000000000000000000000000000000000..72f23373c143e4f113f04d5228966e5e79c448a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py @@ -0,0 +1,153 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. 
Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=4) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, 
view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +mul_Tensor = CallFunction(aten.mul.Tensor, bmm_default_2, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, view_default_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, div_Tensor, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = 
CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +_sfdp_pattern_24_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +view_default_5 = 
CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_1, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, convert_element_type_default, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_half_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + 
view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +_sfdp_pattern_24_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7f7519ad0570d2c2f700d4081c9b7253d16657 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. 
Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = 
CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) 
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, 
KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = 
CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9cfd506f950415f4f90b49edf83815432a641c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) 
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) 
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + 
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), 
_users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, 
view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, 
convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..f211e56b17a0a19c05bcb0efc681ed2623f4edf7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,178 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) 
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), 
Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, 
view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) 
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = 
CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..01304bf415163909c5ec5b03064ce064697e1de9 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,194 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, 
Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) 
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = 
CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = 
CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = 
CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, 
convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..b463c7e64a6130dd85063f5fb88c2317c392c8f2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, 
convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, 
clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), 
pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, 
convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) 
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, 
Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..3faff67089b17ad370d4d7642539c7ce3fd5d235 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = 
CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, 
fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) 
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, 
div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, 
Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = 
CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf77120e836a5b577ea8a335f00bd63fd27163a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList 
= CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, 
convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, 
memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), 
device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, 
convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, 
Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) 
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..70d672442170905a411de63187a5b579b286bf73 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py @@ -0,0 +1,53 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta')) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True) +view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha')) +addmm_pattern_training = MultiOutputPattern([addmm_default, + view_default, + mul_Scalar_1, + mul_Scalar_2, + None, + None +]) + + +addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), 
alpha=KeywordArg('alpha'), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5ac59d6f06c97523e071e9b3ea78516ff09c0e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1')) +bmm_pattern_training = MultiOutputPattern([bmm_default, + bmm_default_1, + bmm_default_2 +]) + + +bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..058a2f881e3a52cb147cfd3fa0ef2bbd0a25945a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mm_pattern_training = MultiOutputPattern([mm_default, + mm_default_1, + mm_default_2 +]) + + +mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..48af3b6ac6805af428b9923d3d5d6657b7af6193 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f00ed25f79746a623201cda0701a030a89e8ab7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41db8939d197a18a43e950f9b0a77c5195c96737 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/custom_op.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/custom_op.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d62b5328e78db361f557a35e1f9265398b3063c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/custom_op.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78f04511cde116e741ffb8c1964cf06a70af043e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36afecf4142e7c46c957bb635f25d18236a046a9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_grouped.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_grouped.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfccfee50f1a0fb84e21575db804701edbaca397 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_grouped.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3541b1a92fd5fdb132239305d2bb1952714e0b5a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67a604adcb1e6057015f7fa1833d766b37d7c61b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__init__.py @@ -0,0 +1,3 @@ +# mypy: allow-untyped-defs +# Import so here and then reimport above so that register_lowering gets triggered +from . import flex_attention, flex_decoding diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b604514f30d1436de9db6433e00fea28a621e8fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/common.py @@ -0,0 +1,356 @@ +# mypy: allow-untyped-defs +"""Common utilities and functions for flex attention kernels""" + +import math +from collections.abc import Sequence +from functools import partial +from pathlib import Path +from typing import Any, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map, tree_map_only + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda_combined_scheduling import _IntLike +else: + _IntLike = Union[int, sympy.Expr] + + +from ...ir import ( + ComputedBuffer, + ExternKernel, + FixedLayout, + FlexibleLayout, + get_fill_order, + InputBuffer, + IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, + ShapeAsConstantBuffer, + StorageBox, + Subgraph, + TensorBox, +) +from ...lowering import ( + _full, + check_and_broadcast_indices, + expand, + index_output_size_and_inner_fn, + 
to_dtype, +) +from ...select_algorithm import realize_inputs +from ...utils import load_template + + +SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def zeros_and_scatter_lowering(shape: list[int], indices, values): + """To support backwards on captured buffers we register a specific lowering for our specific custom up""" + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +def get_fwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults +) -> list[Optional[ComputedBuffer]]: + subgraph_buffer = ( + # pyrefly: ignore [bad-assignment] + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + # pyrefly: ignore [bad-assignment] + mask_graph_buffer + if 
isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + # pyrefly: ignore [not-iterable] + return [*subgraph_buffer, *mask_graph_buffer] + + +def build_subgraph_module_buffer( + args: list[Union[TensorBox, ShapeAsConstantBuffer]], + graph_module: torch.fx.GraphModule, +) -> SubgraphResults: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph. Contains both fixed and lifted inputs. + subgraph: The Subgraph ir for which to produce the output node + """ + # This one we gotta keep lazy + from ...subgraph_lowering import PointwiseSubgraphLowering + + pw_subgraph = PointwiseSubgraphLowering( + graph_module, + root_graph_lowering=V.graph, + allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]), + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, + ) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*args) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: + return None + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + device = output_buffer.data.get_device() + assert device is not None + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=device, + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: 
ignore[arg-type] + ) + return subgraph_buffer + + return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) + + +def build_subgraph_buffer( + args: list[Union[TensorBox, ShapeAsConstantBuffer]], subgraph: Subgraph +) -> SubgraphResults: + return build_subgraph_module_buffer(args, subgraph.graph_module) + + +def maybe_realize(args: list[Optional[IRNode]]): + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" + return tree_map( + lambda x: ( + realize_inputs(x) + if x is not None and not isinstance(x, sympy.Symbol) + else x + ), + args, + ) + + +def freeze_irnodes(tree: Any) -> Any: + """Freeze layouts for every IRNode contained in a pytree.""" + + if tree is None: + return None + + def _freeze(node: IRNode) -> IRNode: + try: + node.freeze_layout() + except NotImplementedError: + pass + return node + + return tree_map_only(IRNode, _freeze, tree) + + +def create_placeholder( + name: str, + dtype: torch.dtype, + device: torch.device, + size: Optional[list[int]] = None, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer( + name=name, + layout=FixedLayout( + device, + dtype, + size if size else [], + FlexibleLayout.contiguous_strides(size) if size else [], + ), + ) + return TensorBox.create(input_buffer) + + +def construct_strides( + sizes: Sequence[_IntLike], + fill_order: Sequence[int], +) -> Sequence[_IntLike]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) + strides: list[_IntLike] = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride: _IntLike = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def 
infer_dense_strides( + size: Sequence[_IntLike], + orig_strides: Sequence[_IntLike], +): + """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp + + Args: + size: The size of the output tensor + orig_strides: The strides of the input tensor + Returns: + List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation. + The returned strides follow the same stride propagation rules as TensorIterator. This matches + The behavior of empty_like() + """ + fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env) + return construct_strides(size, fill_order) + + +def create_indices_fake(x) -> torch.Tensor: + """Create a fake indices that is used for autotuning.""" + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device()) + indices = indices.expand(size).contiguous() + return indices + + +def create_num_blocks_fake_generator(sparse_indices): + """Create a fake num_blocks that is used for autotuning. + + The idea here is that we need to create a real tensor with real data + that's representative for benchmarking. + For example, returning all zeros for the `kv_num_blocks` input would mean + that we are computing 0 blocks for each row, which would provide bogus + autotuning results. + + In this case, we choose to use min(16, max_block) blocks, because I + (Horace) think it'll probably result in pretty representative performance. + If it's too short then prefetching won't help. If it's too long then + autotuning will take longer for no good reason. 
+ """ + + def create_num_blocks_fake(x) -> torch.Tensor: + num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1]) + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + return torch.full( + size, + num_blocks_for_autotuning, + dtype=x.get_dtype(), + device=x.get_device(), + ) + + return create_num_blocks_fake + + +def contiguous_last_dim(x): + """Ensure that realized IR node has a contiguous stride in the last dimension.""" + strides = x.maybe_get_stride() + if strides and strides[-1] != 1: + contiguous_stride_order = list(reversed(range(len(x.get_size())))) + return ExternKernel.require_stride_order(x, contiguous_stride_order) + return x + + +def set_head_dim_values( + kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars +): + """ + Mutates kernel options, adding head dimension calculations. + + Args: + kernel_options: Dictionary to populate with options + qk_head_dim: Query/Key head dimension + v_head_dim: Value head dimension + graph_sizevars: Graph size variables object with guard_int method + + """ + # QK dimensions + qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static) + kernel_options.setdefault( + "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static) + ) + + # V dimensions + v_head_dim_static = graph_sizevars.guard_int(v_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static) + kernel_options.setdefault( + "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static) + ) + + # Safety flag + kernel_options.setdefault( + "SAFE_HEAD_DIM", + is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static), + ) + + +def is_power_of_2(n): + return n != 0 and ((n & (n - 1)) == 0) + + +def next_power_of_two(n): + if n <= 0: + return 1 + return 2 ** math.ceil(math.log2(n)) + + +_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR) + + +# 
Template strings have been moved to templates/common.py.jinja diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_attention.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d36b8d56cc711504dad6f9071453e887e23e1a83 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_attention.py @@ -0,0 +1,977 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel""" + +from __future__ import annotations + +import logging +import math +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, cast, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.nn.attention.flex_attention import _Backend + +from ...ir import ComputedBuffer, ExternKernel, FixedLayout, TensorBox +from ...lowering import empty, empty_strided, lowerings, register_lowering +from ...select_algorithm import ( + autotune_select_algorithm, + SymbolicGridFn, + TritonTemplate, +) +from .common import ( + build_subgraph_buffer, + create_indices_fake, + create_num_blocks_fake_generator, + create_placeholder, + freeze_irnodes, + get_fwd_subgraph_outputs, + infer_dense_strides, + load_flex_template, + maybe_realize, + set_head_dim_values, + SubgraphResults, +) +from .flex_cpu import lower_cpu +from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel +from .flex_flash_attention import ( + _use_flex_flash_attention, + _use_flex_flash_attention_backward, + create_flex_flash_attention_backward_kernel, + create_flex_flash_attention_kernel, +) + + +if TYPE_CHECKING: + from ...template_heuristics.triton import FlexBwDConfig, FlexConfig + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +Expr = sympy.Expr + + +def 
def _sanitize_kernel_options_for_triton(
    kernel_options: dict[str, Any],
) -> tuple[dict[str, Any], _Backend]:
    """Separate the lowering-only ``BACKEND`` entry from the kernel options.

    ``BACKEND`` is only consulted while lowering, so it is removed from the
    copy that is later forwarded to the Triton constexpr dict. The caller's
    dictionary is left untouched.

    Returns:
        ``(options_without_backend, backend)`` where ``backend`` falls back
        to ``"AUTO"`` when the key is absent.
    """
    triton_options = {**kernel_options}
    backend = triton_options.pop("BACKEND", "AUTO")
    return triton_options, cast(_Backend, backend)


@SymbolicGridFn
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
    """Launch grid for the forward kernel.

    The grid is (ceil_div(num_queries, BLOCK_M), batch_size, q_heads): each
    program owns one block of queries for one (batch, head) pair and walks
    over the key/value blocks to produce the final attention output.
    """
    num_query_blocks = cdiv(num_queries, meta["BLOCK_M"])
    return (num_query_blocks, batch_size, q_heads)


def get_float32_precision():
    """Return the quoted precision string (``"'ieee'"`` or ``"'tf32'"``).

    ``torch.backends.cuda.matmul.fp32_precision`` decides when it is set to
    something other than ``"none"``; otherwise the global
    ``torch.get_float32_matmul_precision()`` is consulted. HIP builds and
    available MTIA devices always get ``"'ieee'"``.
    """
    matmul_backend = torch.backends.cuda.matmul
    if matmul_backend.fp32_precision != "none":
        wants_ieee = matmul_backend.fp32_precision == "ieee"
    else:
        wants_ieee = torch.get_float32_matmul_precision() == "highest"

    if wants_ieee or torch.version.hip or torch.mtia.is_available():
        return "'ieee'"
    return "'tf32'"


# Forward template source = main kernel + shared utilities + common helpers.
flex_attention_template = TritonTemplate(
    name="flex_attention",
    grid=flex_attention_grid,
    source=load_flex_template("flex_attention")
    + load_flex_template("utilities")
    + load_flex_template("common"),
)
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(
    query,
    key,
    value,
    subgraph,
    block_mask,
    scale,
    kernel_options: dict[str, Any],
    score_mod_other_buffers,
    mask_mod_other_buffers,
):
    """The main lowering for the flex_attention hop

    Dispatch order, as implemented below:
    1. Cpu specific CPP template (``lower_cpu``) when the query lives on CPU
    2. Flex Decode Triton Template when ``BACKEND`` allows it and
       ``_use_flex_decoding`` accepts the inputs
    3. Flex flash-attention kernel when ``_use_flex_flash_attention`` accepts
    4. Base Triton Template, autotuned over the configs returned by
       ``V.choices.get_flex_attention_fwd_configs``

    Args:
        query, key, value: IR nodes whose ``get_size()`` unpacks into four
            dimensions (batch, heads, sequence length, head dim).
        subgraph: score_mod subgraph, lowered via ``build_subgraph_buffer``.
        block_mask: 13-element sequence unpacked below (lengths, kv/q block
            counts and indices, their "full" variants, sparse block sizes,
            and the mask_mod graph).
        scale: default value for the ``SM_SCALE`` kernel option.
        kernel_options: user options. ``BACKEND`` is stripped before reaching
            Triton; ``fwd_``-prefixed keys are un-prefixed and ``bwd_``-prefixed
            keys are dropped per config.
        score_mod_other_buffers, mask_mod_other_buffers: extra buffers
            captured by the score_mod / mask_mod subgraphs.

    Returns:
        On the base Triton path, ``(out, logsumexp, max_scores)``. The other
        paths return whatever the delegated lowering returns.

    Raises:
        NotImplementedError: a query/key or value head dim is below 16.
        RuntimeError: ``BACKEND='TRITON_DECODE'`` was requested but flex
            decoding cannot handle the input.
        ValueError: the only available config has BLOCK_M/BLOCK_N that do not
            divide the sparse block sizes.
    """
    if query.get_device().type == "cpu":
        return lower_cpu(
            query,
            key,
            value,
            subgraph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
    # below is cuda path if device is not cpu
    # tl.dot does not support embedding size less than 16
    small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16))
    small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16))
    if small_dqk or small_dv:
        raise NotImplementedError(
            f"NYI: embedding dimension of the query, key, and value must be "
            f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}"
        )

    (
        _,  # q_length
        _,  # kv_length
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_Q_BLOCK_SIZE,
        SPARSE_KV_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    # score_mod subgraph inputs: the score plus four int32 index placeholders.
    placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("score", query.get_dtype()),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    subgraph_buffer = build_subgraph_buffer(
        placeholder_inps + list(score_mod_other_buffers), subgraph
    )
    freeze_irnodes(subgraph_buffer)

    # mask_mod subgraph inputs: same index placeholders, without the score.
    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    freeze_irnodes(mask_graph_buffer)

    kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options)
    # Mark symbols in custom kernel options as static shapes and add guards.
    kernel_options = {
        k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
        for k, v in kernel_options.items()
    }
    kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
    # GQA is in effect when query and key disagree on the head dimension (dim 1).
    enable_gqa = V.graph.sizevars.evaluate_expr(
        sympy.Ne(query.get_size()[1], key.get_size()[1]),
    )

    can_use_decode = _use_flex_decoding(
        query, kv_indices, value, kernel_options, enable_gqa
    )
    use_decode = (backend == "TRITON_DECODE") or (backend == "AUTO" and can_use_decode)

    # An explicit TRITON_DECODE request is an error rather than a silent fallback.
    if backend == "TRITON_DECODE" and not can_use_decode:
        raise RuntimeError(
            "BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used for this input. "
            "flex_decoding is only available for short sequence lengths with specific configurations."
        )

    if use_decode:
        return create_flex_decoding_kernel(
            query,
            key,
            value,
            block_mask,
            scale,
            kernel_options,
            subgraph_buffer,
            mask_graph_buffer,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )

    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )

    if _use_flex_flash_attention(
        subgraph,
        mask_graph,
        kernel_options,
        num_score_mod_placeholders=len(placeholder_inps),
        backend=backend,
    ):
        return create_flex_flash_attention_kernel(
            query,
            key,
            value,
            block_mask,
            scale,
            kernel_options,
            subgraph_buffer,
            mask_graph_buffer,
            score_mod_other_buffers,
            mask_mod_other_buffers,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            mask_graph=mask_graph,
            subgraph=subgraph,
        )

    # ---- Base Triton template path ----
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

    freeze_irnodes(score_mod_other_buffers)
    freeze_irnodes(mask_mod_other_buffers)

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
    )
    assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), (
        "Query length must be greater than 0"
    )
    assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), (
        "Key length must be greater than 0"
    )

    B = Bq

    # NOTE(review): the backward lowering derives IS_DIVISIBLE through
    # sizevars.statically_known_true(...); this path compares the raw size
    # expressions directly -- confirm the difference is intentional.
    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    # NB it is okay that the v_head_dim is different
    # We are using these to match fill order of the output.
    q_strides = query.get_stride()
    # Construct output layout with strides matching the query.
    out_size = [B, Hq, seq_len_q, v_head_dim]
    out_strides = infer_dense_strides(out_size, q_strides)

    layout = FixedLayout(
        query.get_device(),
        query.get_dtype(),
        [B, Hq, seq_len_q, v_head_dim],
        stride=[sympy.sympify(s) for s in out_strides],
    )
    # see NOTE:[TritonTemplates with multiple outputs]
    logsumexp_shape = [B, Hq, seq_len_q]
    logsumexp = empty_strided(
        logsumexp_shape,
        None,
        dtype=torch.float32,  # The logsumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    max_scores = empty_strided(
        logsumexp_shape,  # Same shape as logsumexp
        None,
        dtype=torch.float32,  # The max scores are always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    kernel_options.setdefault("SM_SCALE", scale)

    # Determine GQA broadcast factor.
    gqa_shared_heads = Hq // Hkv
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
    # full_kv_num_blocks is None if partial blocks are not computed
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Substitute empty buffers so the template still receives an input
        # in each "full" slot.
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )

    set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)

    choices: list[Any] = []

    dtype = query.get_dtype()
    head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
    configs: list[FlexConfig] = V.choices.get_flex_attention_fwd_configs(
        head_dim, dtype, query.get_device().type
    )

    # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)

    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.
    original_kernel_options = kernel_options.copy()
    # Default config for warp specialization
    num_consumer_groups, num_buffers_warp_spec = 0, 0

    for conf in configs:
        # Each candidate starts from a fresh copy so per-config defaults do
        # not leak into the next iteration.
        cur_kernel_options = original_kernel_options.copy()
        # Performance tuning
        # Triton parameters
        # Remove prefix for forward kernels options and delete backward kernel options.
        for k in list(cur_kernel_options.keys()):
            if k.startswith("fwd_"):
                v = cur_kernel_options.pop(k)
                cur_kernel_options[k[4:]] = v
            if k.startswith("bwd_"):
                cur_kernel_options.pop(k)
        cur_kernel_options.setdefault("num_stages", conf.num_stages)
        cur_kernel_options.setdefault("num_warps", conf.num_warps)
        # NOTE(review): this branch only runs when "num_consumer_groups" is
        # already present and truthy, so the first setdefault below cannot
        # change anything; only "num_buffers_warp_spec" may be defaulted.
        if cur_kernel_options.get("num_consumer_groups", False):
            cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups)
            cur_kernel_options.setdefault(
                "num_buffers_warp_spec", num_buffers_warp_spec
            )

        # USE TMA = false by default
        cur_kernel_options.setdefault("USE_TMA", False)

        cur_kernel_options.setdefault("BLOCK_M", conf.block_m)
        cur_kernel_options.setdefault("BLOCK_N", conf.block_n)
        # Blocksparse options
        cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
        cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        # Skip configs whose tile sizes do not evenly divide the sparse block
        # sizes; with a single config there is nothing to fall back to.
        if (
            cur_kernel_options["SPARSE_KV_BLOCK_SIZE"] % cur_kernel_options["BLOCK_N"]
            != 0
            or cur_kernel_options["SPARSE_Q_BLOCK_SIZE"] % cur_kernel_options["BLOCK_M"]
            != 0
        ):
            if len(configs) == 1:
                raise ValueError(
                    f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
                    f"got Q_BLOCK_SIZE={cur_kernel_options['SPARSE_Q_BLOCK_SIZE']} and "
                    f"KV_BLOCK_SIZE={cur_kernel_options['SPARSE_KV_BLOCK_SIZE']}."
                )
            continue

        # ROCm specific kernargs
        for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
            if hasattr(conf, attrib):
                cur_kernel_options[attrib] = getattr(conf, attrib)

        error = flex_attention_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                max_scores,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout,
            subgraphs=[
                subgraph_buffer,
                mask_graph_buffer,
            ],
            mutated_inputs=[
                logsumexp,
                max_scores,
            ],
            call_sizes=query.get_size(),
            **cur_kernel_options,
        )
        # With only one candidate, surface its failure instead of ending up
        # with an empty choice list.
        if error is not None and len(configs) == 1:
            raise error
    inputs_for_autotuning = (
        [
            query,
            key,
            value,
            logsumexp,
            max_scores,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    # Keys are positions in inputs_for_autotuning: 5/7 are the (full) kv block
    # counts, 6/8 the matching index buffers.
    input_gen_fns = {
        5: create_num_blocks_fake_generator(kv_indices),
        6: create_indices_fake,
        7: create_num_blocks_fake_generator(full_kv_indices),
        8: create_indices_fake,
    }

    out = autotune_select_algorithm(
        "flex_attention",
        choices,
        # Need to filter out symbols since there is an invariant
        # that all input_nodes are of type IRNode
        [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)],
        layout,
        input_gen_fns=input_gen_fns,
    )

    # need subgraph inputs and outputs to analyze all symints used in flex attention
    out.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
        mask_mod_other_buffers
    )
    out.data.data.subgraph_outs = get_fwd_subgraph_outputs(
        subgraph_buffer, mask_graph_buffer
    )

    return (out, logsumexp, max_scores)
@SymbolicGridFn
def flex_attention_backward_grid(
    batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta, *, cdiv
):
    """Launch grid for the backward kernel.

    Axis 0 holds ceil_div(num_queries, BLOCK_M2) * (q_heads // kv_heads)
    programs plus ceil_div(num_key_value, BLOCK_N1) programs; axes 1 and 2
    are batch_size and kv_heads.

    The original authors note that the kernel currently parallelizes only
    over batch * kv_heads and that also parallelizing over
    ceil_div(q_heads // kv_heads * num_key_value, key_value_block_size)
    would need atomic gradient updates or a two-pass kernel design.
    """
    query_programs = cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
    kv_programs = cdiv(num_key_value, meta["BLOCK_N1"])
    return (query_programs + kv_programs, batch_size, kv_heads)


# Backward template source = backward kernel + shared utilities.
flex_attention_backward_template = TritonTemplate(
    name="flex_attention_backward",
    grid=flex_attention_backward_grid,
    source=load_flex_template("flex_backwards") + load_flex_template("utilities"),
)


def validate_joint_graph(joint_graph: torch.fx.Graph):
    """Pre-lowering check that turns an unsupported joint graph into a clear error.

    Raises:
        NotImplementedError: a ``flex_lib.zeros_and_scatter`` call has a user
            other than the graph output.
    """
    for node in joint_graph.nodes:
        if node.op != "call_function":
            continue
        if node.target is not torch.ops.flex_lib.zeros_and_scatter.default:
            continue
        # The scatter result may only feed the graph output directly.
        if any(user.op != "output" for user in node.users):
            raise NotImplementedError(
                "Using multiple indexing operations on the same tensor that requires gradients "
                "in a score_mod function is not currently supported. "
                "This typically happens when indexing the same tensor multiple times, like:\n\n"
                " def score_mod(score, b, h, q_idx, kv_idx):\n"
                " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n"
                "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n"
                " bias1 = bias.clone()\n"
                " def score_mod(score, b, h, q_idx, kv_idx):\n"
                " return score + bias[q_idx] + bias1[kv_idx]\n\n"
                "Note that this solution will use additional memory."
            )
    return
@dataclass(frozen=True)
class JointOutputResult:
    """Immutable bundle of the buffers extracted from the joint subgraph outputs."""

    # First joint output: the gradient w.r.t. the score input.
    grad_input: ComputedBuffer
    # Non-None gradient buffers for captured inputs, in output order.
    captured_grads_compute: list[ComputedBuffer]
    # One entry per captured-gradient slot; None where the slot was None.
    captured_grads: list[Optional[TensorBox]]
    # captured_grads with the None entries removed.
    mutated_grads: list[TensorBox]


def process_joint_outputs(
    all_joint_outputs: SubgraphResults, num_placeholders: int
) -> JointOutputResult:
    """Split the joint subgraph outputs into the pieces the lowering needs.

    Args:
        all_joint_outputs: List of all the outputs from build_subgraphs;
            element 0 must not be None.
        num_placeholders: Number of placeholder inputs; entries from index
            ``num_placeholders - 1`` onward are treated as gradients of the
            captured buffers.

    Returns:
        JointOutputResult containing processed buffers and gradients
    """
    assert isinstance(all_joint_outputs, list)
    assert all_joint_outputs[0] is not None, (
        "joint_subgraph_buffer is None - this is a bug!"
    )

    grad_input = all_joint_outputs[0]
    # One slot per captured buffer; a slot is None when no gradient buffer
    # was produced for it.
    captured = all_joint_outputs[num_placeholders - 1 :]

    def _as_tensorbox(buf):
        """Wrap the graph buffer registered under ``buf.name`` in a TensorBox."""
        assert isinstance(buf, ComputedBuffer)
        assert buf.name is not None
        return TensorBox.create(V.graph.get_buffer(buf.name))

    # Only the non-None buffers are inlined into the kernel.
    compute_buffers = [buf for buf in captured if buf is not None]
    boxed = [None if buf is None else _as_tensorbox(buf) for buf in captured]

    return JointOutputResult(
        grad_input=grad_input,
        captured_grads_compute=compute_buffers,
        captured_grads=boxed,
        mutated_grads=[box for box in boxed if box is not None],
    )
# TODO: We probably also need a layout constraint?
@register_lowering(
    torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
)
def flex_attention_backward(*args, **kwargs):
    """Lowering for the flex_attention_backward op in triton

    ``args`` must unpack into exactly 14 positional values (see the first
    tuple below); ``kwargs`` is accepted but not read by this function.

    Dispatch: the flex flash-attention backward kernel is used when
    ``_use_flex_flash_attention_backward`` accepts; otherwise the Triton
    backward template is autotuned over the configs from
    ``V.choices.get_flex_attention_bwd_configs``.

    Returns:
        On the Triton path, ``(grad_query, grad_key, grad_value,
        captured_grads)`` where ``captured_grads`` is a tuple with one entry
        per captured-gradient slot. When the key/value batch is 1 but the
        query batch is larger, the key/value gradients are summed over the
        batch axis with ``keepdims=True``.
    """
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    (
        _,  # q_length
        _,  # kv_length
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_Q_BLOCK_SIZE,
        SPARSE_KV_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    (
        query,
        key,
        value,
        logsumexp,
        grad_out,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            logsumexp,
            grad_out,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )

    device = query.get_device()
    dtype = query.get_dtype()
    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()

    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
    )

    kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options)
    # Mark symbols in custom kernel options as static shapes and add guards.
    kernel_options = {
        k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
        for k, v in kernel_options.items()
    }
    kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
    # IS_DIVISIBLE defaults to True only when both sequence lengths are
    # statically known to be multiples of 128.
    seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0)
    seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0)
    if seq_q_divisible and seq_kv_divisible:
        kernel_options.setdefault("IS_DIVISIBLE", True)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", False)

    # Forward score_mod inputs: the score plus four int32 index placeholders.
    fwd_placeholder_inps = [
        create_placeholder(name, dtype, device)
        for name, dtype in [
            ("score", dtype),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    fw_subgraph_buffer = build_subgraph_buffer(
        fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
    )
    freeze_irnodes(fw_subgraph_buffer)

    # The joint graph takes the forward inputs plus the incoming score gradient.
    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
    ]
    # Sometimes we have weird unused nodes here
    joint_graph.graph_module.graph.eliminate_dead_code()

    # It is hard to raise nice errors for some joint graphs during subgraph lowering
    # This lets us do some checks before attempting to lower
    validate_joint_graph(joint_graph.graph_module.graph)

    all_joint_outputs = build_subgraph_buffer(
        joint_placeholder_inps + list(score_mod_other_buffers),
        joint_graph,
    )
    freeze_irnodes(all_joint_outputs)

    joint_outputs = process_joint_outputs(
        all_joint_outputs, len(joint_placeholder_inps)
    )

    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    freeze_irnodes(mask_graph_buffer)

    if _use_flex_flash_attention_backward(
        fw_graph,
        mask_graph,
        backend=backend,
    ):
        return create_flex_flash_attention_backward_kernel(
            query, key, value, out, logsumexp, grad_out, scale, kernel_options
        )

    # Construct layout with stride order matching K
    # The batch dimension is Bq (not Bkv): the key gradient is produced per
    # query batch and reduced at the end when Bkv == 1.
    key_size = [Bq, Hkv, seq_len_kv, qk_head_dim]
    key_strides = infer_dense_strides(key_size, key.get_stride())

    layout_broadcasted_k = FixedLayout(
        key.get_device(),
        key.get_dtype(),
        key_size,
        stride=[sympy.sympify(s) for s in key_strides],
    )

    # Create delta, which is needed by the bwd kernel:
    # delta = sum(out * grad_out, axis=-1) - grad_logsumexp / ln(2)
    grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2))
    mul_delta = lowerings[aten.mul](out, grad_out)
    delta = lowerings[aten.sum](mul_delta, axis=-1)
    delta = lowerings[aten.sub](delta, grad_lse_exp2)
    delta = ExternKernel.require_contiguous(delta)

    grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])

    # # see NOTE:[TritonTemplates with multiple outputs]
    query_size = [Bq, Hq, seq_len_q, qk_head_dim]
    grad_query_strides = infer_dense_strides(query_size, query.get_stride())
    grad_query = empty_strided(
        query_size,
        stride=[sympy.sympify(s) for s in grad_query_strides],
        dtype=query.get_dtype(),
        device=query.get_device(),
    )

    # Construct output layout with stride order matching value
    value_size = [Bq, Hkv, seq_len_kv, v_head_dim]
    value_strides = infer_dense_strides(value_size, value.get_stride())

    broadcasted_grad_value = empty_strided(
        value_size,
        stride=[sympy.sympify(s) for s in value_strides],
        dtype=value.get_dtype(),
        device=value.get_device(),
    )

    kernel_options.setdefault("SM_SCALE", scale)

    # Determine GQA factor
    gqa_shared_heads = Hq // Hkv
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
    # The check below treats full_kv_num_blocks being None as "partial blocks
    # are not computed".
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Substitute empty buffers so the template still receives an input
        # in each "full" slot.
        full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = (
            empty(0, device=query.get_device()) for _ in range(4)
        )

    set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)

    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)

    choices: list[Any] = []

    dtype = query.get_dtype()
    head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
    configs: list[FlexBwDConfig] = V.choices.get_flex_attention_bwd_configs(
        head_dim, dtype, query.get_device().type
    )

    # Default config for warp specialization
    num_consumer_groups, num_buffers_warp_spec = 0, 0

    original_kernel_options = kernel_options.copy()

    for conf in configs:
        # Skip configs whose tile sizes do not evenly divide the sparse block sizes.
        if (
            SPARSE_KV_BLOCK_SIZE % conf.block_n1 != 0
            or SPARSE_Q_BLOCK_SIZE % conf.block_m1 != 0
            or SPARSE_KV_BLOCK_SIZE % conf.block_n2 != 0
            or SPARSE_Q_BLOCK_SIZE % conf.block_m2 != 0
        ):
            continue

        # Performance tuning
        # Triton heuristics
        cur_kernel_options = original_kernel_options.copy()
        # Remove prefix for backward kernels options and delete forward kernel options.
        for k in list(cur_kernel_options.keys()):
            if k.startswith("bwd_"):
                v = cur_kernel_options.pop(k)
                cur_kernel_options[k[4:]] = v
            if k.startswith("fwd_"):
                cur_kernel_options.pop(k)
        cur_kernel_options.setdefault("num_warps", conf.num_warps)
        cur_kernel_options.setdefault("num_stages", conf.num_stages)

        # NOTE(review): this branch only runs when "num_consumer_groups" is
        # already present and truthy, so the first setdefault below cannot
        # change anything; only "num_buffers_warp_spec" may be defaulted.
        if cur_kernel_options.get("num_consumer_groups", False):
            cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups)
            cur_kernel_options.setdefault(
                "num_buffers_warp_spec", num_buffers_warp_spec
            )

        cur_kernel_options.setdefault("BLOCK_M1", conf.block_m1)
        cur_kernel_options.setdefault("BLOCK_N1", conf.block_n1)
        cur_kernel_options.setdefault("BLOCK_M2", conf.block_m2)
        cur_kernel_options.setdefault("BLOCK_N2", conf.block_n2)

        # Blocksparse options
        cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
        cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        # ROCm specific kernargs
        for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
            if hasattr(conf, attrib):
                cur_kernel_options[attrib] = getattr(conf, attrib)

        # NOTE(review): unlike the forward lowering, the value returned by
        # maybe_append_choice is not inspected here.
        flex_attention_backward_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                delta,
                grad_out,
                grad_query,
                broadcasted_grad_value,
                kv_num_blocks,
                kv_indices,
                q_num_blocks,
                q_indices,
                full_kv_num_blocks,
                full_kv_indices,
                full_q_num_blocks,
                full_q_indices,
            ],
            layout=layout_broadcasted_k,  # We use store_output only for grad_key
            subgraphs=[
                fw_subgraph_buffer,
                joint_outputs.grad_input,
                mask_graph_buffer,
                joint_outputs.captured_grads_compute,
            ],
            mutated_inputs=[
                grad_query,
                broadcasted_grad_value,
                *joint_outputs.mutated_grads,
            ],
            call_sizes=query.get_size() + key.get_size()[1:3],
            **cur_kernel_options,
        )
    inputs_for_autotuning = (
        # pyrefly: ignore [unsupported-operation]
        [
            query,
            key,
            value,
            logsumexp,
            delta,
            grad_out,
            grad_query,
            broadcasted_grad_value,
            kv_num_blocks,
            kv_indices,
            q_num_blocks,
            q_indices,
            full_kv_num_blocks,
            full_kv_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
        + joint_outputs.mutated_grads
    )
    # Keys are positions in inputs_for_autotuning (block counts at even
    # positions 8-14, matching index buffers at odd positions 9-15).
    input_gen_fns = {
        8: create_num_blocks_fake_generator(kv_indices),  # kv_num_blocks
        9: create_indices_fake,
        10: create_num_blocks_fake_generator(q_indices),  # q_num_blocks
        11: create_indices_fake,
        12: create_num_blocks_fake_generator(full_kv_indices),  # full_kv_num_blocks
        13: create_indices_fake,
        14: create_num_blocks_fake_generator(full_q_indices),  # full_q_num_blocks
        15: create_indices_fake,
    }

    broadcasted_grad_key = autotune_select_algorithm(
        "flex_attention_backward",
        choices,
        [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)],
        layout_broadcasted_k,
        input_gen_fns=input_gen_fns,
    )  # [Bq, Hkv, seq_len_kv, k_head_dim]

    # need subgraph inputs and outputs to analyze all symints used in flex attention
    broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
        mask_mod_other_buffers
    )
    broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs(
        fw_subgraph_buffer, mask_graph_buffer, joint_outputs
    )

    if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)):
        grad_key = broadcasted_grad_key
        grad_value = broadcasted_grad_value
    else:
        # Key/value were broadcast over the query batch: reduce their
        # gradients back to a batch of 1.
        assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
            f"Bq and Bkv must broadcastable. "
            f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
            f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
        )
        grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
        grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)

    return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads))


def get_bwd_subgraph_outputs(
    subgraph_buffer: SubgraphResults,
    mask_graph_buffer: SubgraphResults,
    joint_outputs: JointOutputResult,
) -> list[Optional[Union[ComputedBuffer, TensorBox]]]:
    """Flatten every backward subgraph output into a single list.

    ``subgraph_buffer`` and ``mask_graph_buffer`` are wrapped in a
    one-element list unless they are already a ``Sequence``. The result is
    their elements followed by ``grad_input``, ``captured_grads_compute``,
    ``captured_grads`` and ``mutated_grads`` from ``joint_outputs``, in that
    order (``captured_grads`` may contribute ``None`` entries).
    """
    subgraph_buffer = (
        # pyrefly: ignore [bad-assignment]
        subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
    )
    mask_graph_buffer = (
        # pyrefly: ignore [bad-assignment]
        mask_graph_buffer
        if isinstance(mask_graph_buffer, Sequence)
        else [mask_graph_buffer]
    )
    joint_output_buffers = [
        joint_outputs.grad_input,
        *joint_outputs.captured_grads_compute,
        *joint_outputs.captured_grads,
        *joint_outputs.mutated_grads,
    ]

    # pyrefly: ignore [not-iterable]
    return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers]
...codegen.cpp_flex_attention_template import CppFlexAttentionTemplate +from ...ir import Buffer, FixedLayout, TensorBox +from ...select_algorithm import autotune_select_algorithm +from .common import ( + build_subgraph_buffer, + build_subgraph_module_buffer, + contiguous_last_dim, + create_placeholder, + get_fwd_subgraph_outputs, + infer_dense_strides, + maybe_realize, +) + + +def check_cpu_supported(): + requires_avx2_on_cpu = ( + torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ) + supported = ( + requires_avx2_on_cpu + and not torch.xpu.is_available() + and sys.platform != "darwin" + ) + return supported + + +def lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + """CPP based template for flex attention for x86 CPUs""" + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + if kernel_options["OUTPUT_LOGSUMEXP"]: + raise NotImplementedError( + "torch.compile on CPU only supports inference and `return_lse` is not supported yet." + ) + if not check_cpu_supported(): + raise NotImplementedError( + "torch.compile on current platform is not supported for CPU." + ) + + fake_buffers: list[Buffer] = [] # noqa: F821 + + # [Note] Handle the case where the split sizes are not statically known. + # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime. + # We use symbols to represent them during the compilation here. + # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in + # the modification function of the CppFlexAttentionTemplate class. 
+ cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + shape_env = V.graph.sizevars.shape_env + + # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation. + # Mark symbols > 1 to ensure broadcasting is always applied. + # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`. + shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo) + shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo) + + score_dtype = torch.float + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + if subgraph_buffer is not None: + if isinstance(subgraph_buffer, list): + for _buf in subgraph_buffer: + if _buf is not None: + _buf.freeze_layout() + else: + subgraph_buffer.freeze_layout() + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + + # The original mask_graph works on a scalar and only includes + # the logic of calculating the mask value. + # We need to add the logic of applying the mark to the qk_data tensor + # into the graph for the later codegen of this part. 
    # Example:
    # mask_graph:
    # def mask_fn(b, h, q_idx, kv_idx):
    #     mask = q_idx >= kv_idx
    #     return mask
    # The converted_mask_graph should be:
    # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx):
    #     mask = q_idx >= kv_idx
    #     qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf")))
    #     return qk_data
    def convert_mask_graph_module(mask_graph):
        """Return a GraphModule that applies the mask to ``qk_data``.

        Works on a deep copy of ``mask_graph.graph_module``: a ``qk_data``
        placeholder is inserted before the first node, and the graph output
        is rewired to ``where(mask, qk_data, full(size, -inf))``, where
        ``size`` is ``[cur_qSplitSize, cur_kvSplitSize]`` and the fill dtype
        is ``score_dtype`` (both taken from the enclosing scope).
        """
        gm = copy.deepcopy(mask_graph.graph_module)
        graph = gm.graph
        # Add qk_data as the first input
        with graph.inserting_before(next(iter(graph.nodes))):
            qk_data_node = graph.placeholder("qk_data")

        # Find the node that returns the mask
        output_node = None
        for node in graph.nodes:
            if node.op == "output":
                output_node = node
                break

        # Get the mask node
        assert output_node is not None
        mask_node = output_node.args[0]

        size_node = [cur_qSplitSize, cur_kvSplitSize]
        # Create a new node for torch.full
        with graph.inserting_after(mask_node):
            full_node = graph.call_function(
                torch.full,
                args=(size_node, -float("inf")),
                kwargs={"dtype": score_dtype},
            )

        # Create a new node for torch.where
        with graph.inserting_after(full_node):
            where_node = graph.call_function(
                torch.ops.aten.where, args=(mask_node, qk_data_node, full_node)
            )

        # Update the output node to return the result of torch.where
        output_node.args = (where_node,)

        # Sanity-check the edited graph before wrapping it in a new module.
        graph.lint()
        converted = torch.fx.GraphModule(gm, graph)
        return converted
+ pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols + V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [ + x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize) + ] + + buffer_list = ( + placeholder_inps + + list(score_mod_other_buffers) + + mask_graph_placeholder_inps + + list(mask_mod_other_buffers) + ) + for item in buffer_list: + if isinstance(item, TensorBox): + fake_buffers.append(item.data.data) # type: ignore[attr-defined] + + # CPU kernel requires last dim to be contiguous + query, key, value = map(contiguous_last_dim, [query, key, value]) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + if len(OrderedSet([query.get_name(), key.get_name(), value.get_name()])) != 3: + raise NotImplementedError( + "Unsupported for now if query, key, value are the same buffer." + ) + if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]: + raise NotImplementedError( + "`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. " + f"Found input tensors are `{query.get_dtype()}`." + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + B = Bq + + # Construct output layout with strides matching the query. 
+ out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, query.get_stride()) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + _choices: list[Any] = [] + input_nodes = [query, key, value, kv_num_blocks, kv_indices] + if not full_kv_num_blocks: + no_full_kv_block = True + else: + no_full_kv_block = False + input_nodes += [full_kv_num_blocks] + input_nodes += [full_kv_indices] + has_other_buffer = False + kernel_input_name_to_buffer = {} + if score_mod_other_buffers or mask_mod_other_buffers: + has_other_buffer = True + + for prefix, buffers in [ + ("score_others", score_mod_other_buffers), + ("mask_others", mask_mod_other_buffers), + ]: + kernel_input_name_to_buffer.update( + {f"{prefix}_{i}": buf for i, buf in enumerate(buffers)} + ) + input_nodes += [ + value + for value in kernel_input_name_to_buffer.values() + if not isinstance(value, sympy.Symbol) + ] + + skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), ( + "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), ( + "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." 
+ ) + CppFlexAttentionTemplate.add_choices( + choices=_choices, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=None if skip_mask_score else subgraph_buffer, + mask_mod=None if skip_mask_score else mask_graph_buffer, + kv_block_size=SPARSE_KV_BLOCK_SIZE, + q_block_size=SPARSE_Q_BLOCK_SIZE, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len(score_mod_other_buffers), + len_mask_other=len(mask_mod_other_buffers), + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=(cur_qSplitSize, cur_kvSplitSize), + ) + inputs_for_autotuning = [ + query, + key, + value, + ] + res = autotune_select_algorithm( + "flex_attention", + _choices, + inputs_for_autotuning, + layout, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + res.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + res.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (res,) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_decoding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..37113a1d82a8455eca455b6d5e077fa06b952f5a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_decoding.py @@ -0,0 +1,436 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" + +from typing import Any + +import sympy + +import torch +from torch._inductor.virtualized import V + +from ... 
def _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa) -> bool:
    """Return True when the flex decoding kernel should be used for this call.

    Note:
        The number of KV splits is computed from the batch and head dims, so
        both must be statically known. When they are not (or any other check
        below fails) the main flex_attention kernel is used instead.
    """
    sizevars = V.graph.sizevars
    q_shape = query.get_size()
    q_len = q_shape[-2]
    q_heads = q_shape[1]

    flex_forced = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)

    # Shape checks are evaluated up front, in the same order as before, because
    # evaluate_expr may install guards.
    is_short_query = sizevars.evaluate_expr(sympy.Lt(q_len, 128))
    is_nonempty_query = sizevars.evaluate_expr(sympy.Gt(q_len, 0))
    batch_is_static = isinstance(q_shape[0], (int, sympy.Integer))
    heads_are_static = isinstance(q_heads, (int, sympy.Integer))

    mask_heads = kv_indices.get_size()[1]
    mask_heads_cond = sympy.Eq(mask_heads, 1)
    if not enable_gqa:
        mask_heads_cond = sympy.Or(mask_heads_cond, sympy.Eq(mask_heads, q_heads))
    # With GQA, the current flex decoding triton kernel handles the grouped
    # query heads of one kv head in the same block, so different kv num blocks
    # per grouped query head are hard to support. In that case only a
    # block mask with a single head is accepted; anything else falls back to
    # the main flex_attention kernel where each query head gets its own block.
    block_mask_heads_ok = sizevars.evaluate_expr(mask_heads_cond)

    group_size = q_heads // value.get_size()[1]
    group_is_pow2 = sizevars.guard_or_false(
        sympy.And(
            sympy.Gt(group_size, 0),
            sympy.Eq(group_size & (group_size - 1), 0),
        )
    )

    if flex_forced or kernel_options.get("OUTPUT_MAX", False):
        return False
    return (
        is_short_query
        and batch_is_static
        and heads_are_static
        and is_nonempty_query
        and block_mask_heads_ok
        and group_is_pow2
    )
def create_flex_decoding_kernel(*args, **kwargs):
    """Flex decode lowering that is optimized for small Q_LEN and GQA packing.

    Builds one ``flex_decoding`` template choice per usable config, lets
    ``autotune_select_algorithm`` pick one, and then lowers a reduction over
    the KV-split dimension that merges the per-split partial results.

    Args:
        *args: exactly ten positional values, unpacked in this order:
            ``query, key, value, block_mask, scale, kernel_options,
            score_mod_subgraph, mask_mod_subgraph, score_mod_other_buffers,
            mask_mod_other_buffers``. ``block_mask`` is unpacked as a
            13-element sequence; only ``kv_num_blocks``, ``kv_indices``,
            ``full_kv_num_blocks``, ``full_kv_indices`` and
            ``SPARSE_KV_BLOCK_SIZE`` are read from it.
        **kwargs: accepted but never read in this function.

    Returns:
        ``(output, logsumexp)``. ``output`` is converted to the query dtype.
        ``logsumexp`` is built with ``log2``/``exp2`` lowerings.

    Raises:
        ValueError: if ``Hq // Hkv`` is not a power of 2.
        AssertionError: if the value batch size is neither equal to the query
            batch size nor 1.
    """
    (
        query,
        key,
        value,
        block_mask,
        scale,
        kernel_options,
        score_mod_subgraph,
        mask_mod_subgraph,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    (
        _,  # q_length
        _,  # kv_length
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,  # may be None, see has_full_blocks below
        full_kv_indices,  # replaced by a placeholder when there are no full blocks
        _,  # q_num_blocks
        _,  # q_indices
        _,  # full_q_num_blocks,
        _,  # full_q_indices,
        _,  # SPARSE_Q_BLOCK_SIZE,
        SPARSE_KV_BLOCK_SIZE,
        _,
    ) = block_mask

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()

    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
    )

    B = Bq
    # Work on a private copy so the caller's mapping is not mutated by setdefault.
    kernel_options = dict(kernel_options)
    # Mark symbols in custom kernel options as static shapes and add guards.
    kernel_options = {
        k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
        for k, v in kernel_options.items()
    }

    # IS_DIVISIBLE defaults to True only when both sequence lengths are
    # statically known to be multiples of 128; a caller-provided value wins.
    seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0)
    seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0)
    if seq_q_divisible and seq_kv_divisible:
        kernel_options.setdefault("IS_DIVISIBLE", True)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", False)

    # Calculate GQA head sharing
    gqa_shared_heads = Hq // Hkv
    if not is_power_of_2(gqa_shared_heads):
        raise ValueError(
            "Number of shared query heads sharing the same KV head must be power of 2. "
        )
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Create a placeholder full block list in case it is empty
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )

    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
    )
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

    freeze_irnodes(score_mod_other_buffers)
    freeze_irnodes(mask_mod_other_buffers)

    choices: list[Any] = []
    dtype = key.get_dtype()
    head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
    configs = V.choices.get_flex_decode_configs(
        head_dim, dtype, query.get_device().type
    )

    # TODO: fix autotuning.

    kernel_options.setdefault("SM_SCALE", scale)
    kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv))
    # Read back from the options so a caller-supplied SPLIT_KV sizes the buffers.
    MAX_SPLIT_KV = kernel_options["SPLIT_KV"]

    # create config dependent intermediate buffers
    # Dimension 1 of these buffers is the KV split; it is reduced away below.
    buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim]
    buf_ML_shape = buf_ACC_shape[:-1]
    buf_M = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The rowmax is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    buf_L = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The intermediate sumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )

    # The template's own output is the fp32 accumulator, laid out contiguously.
    layout_acc = FixedLayout(
        query.get_device(),
        torch.float32,
        buf_ACC_shape,
        FlexibleLayout.contiguous_strides(buf_ACC_shape),
    )

    set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)

    kernel_options.setdefault(
        "BLOCK_M",
        (
            # m
            # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
            # else  # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
            max(
                next_power_of_2(
                    V.graph.sizevars.size_hint(
                        seq_len_q,
                        fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
                    )
                    * gqa_shared_heads
                ),
                1 if torch.xpu.is_available() else 16,
            )
        ),
    )

    query = ir.ExternKernel.realize_input(query)
    stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()

    # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
    # Done as a strided view: the Hkv axis steps over G query heads at a time.
    gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim)
    gqa_query_stride = (
        stride_b,
        stride_hq * gqa_shared_heads,
        stride_hq,
        stride_seq_len_q,
        stride_qk_head_dim,
    )
    query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)

    # All packed query rows (seq_len_q * G) must fit in a single BLOCK_M tile.
    V.graph.sizevars.check_leq(
        seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
    )

    kernel_options.setdefault(
        "SAFE_M_BOUNDARY",
        ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0,
    )
    # TODO: This feels sketchy
    kernel_options.setdefault("SAFE_N_BOUNDARY", True)
    # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards.
    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)

    # Snapshot taken before the per-config loop; each config starts from a copy.
    original_kernel_options = kernel_options.copy()
    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.

    # Default config for warp specialization
    num_consumer_groups, num_buffers_warp_spec = 0, 0

    for conf in configs:
        # Skip configs whose BLOCK_N does not evenly tile a sparse KV block.
        if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0:
            continue

        cur_kernel_options = original_kernel_options.copy()
        # Remove prefix for forward kernels options and delete backward kernel options.
        for k in list(cur_kernel_options.keys()):
            if k.startswith("fwd_"):
                v = cur_kernel_options.pop(k)
                cur_kernel_options[k[4:]] = v
            if k.startswith("bwd_"):
                cur_kernel_options.pop(k)
        # Performance tuning
        # setdefault: values already present in the kernel options override the config.
        cur_kernel_options.setdefault("BLOCK_N", conf.block_n)
        cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
        cur_kernel_options.setdefault("num_warps", conf.num_warps)
        cur_kernel_options.setdefault("num_stages", conf.num_stages)

        # NOTE(review): this branch only runs when "num_consumer_groups" is
        # already set to a truthy value, so the first setdefault cannot change
        # anything; only "num_buffers_warp_spec" may receive its default here.
        if cur_kernel_options.get("num_consumer_groups", False):
            cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups)
            cur_kernel_options.setdefault(
                "num_buffers_warp_spec", num_buffers_warp_spec
            )

        # Set default to False
        cur_kernel_options.setdefault("USE_TMA", False)

        # Add ROCm-specific parameters if they exist in the config
        for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
            if hasattr(conf, attrib):
                cur_kernel_options[attrib] = getattr(conf, attrib)

        flex_decoding_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                buf_M,
                buf_L,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout_acc,
            subgraphs=[
                score_mod_subgraph,
                mask_mod_subgraph,
            ],
            mutated_inputs=[buf_M, buf_L],
            call_sizes=query.get_size(),
            **cur_kernel_options,
        )

    # sympy.Symbol entries are dropped from the autotuning inputs; the
    # unfiltered lists are still recorded as subgraph_inps further down.
    filtered_score_mod_buffers = [
        buf for buf in score_mod_other_buffers if not isinstance(buf, sympy.Symbol)
    ]
    filtered_mask_mod_buffers = [
        buf for buf in mask_mod_other_buffers if not isinstance(buf, sympy.Symbol)
    ]

    inputs_for_flex_decoding = (
        # pyrefly: ignore [unsupported-operation]
        [
            query,
            key,
            value,
            buf_M,
            buf_L,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + filtered_score_mod_buffers
        + filtered_mask_mod_buffers
    )

    # Keys are positions in inputs_for_flex_decoding: 5/6 are kv_num_blocks /
    # kv_indices and 7/8 are full_kv_num_blocks / full_kv_indices.
    input_gen_fns = {
        5: create_num_blocks_fake_generator(kv_indices),
        6: create_indices_fake,
        7: create_num_blocks_fake_generator(full_kv_indices),
        8: create_indices_fake,
    }

    buf_ACC = autotune_select_algorithm(
        "flex_decoding",
        choices,
        inputs_for_flex_decoding,
        layout_acc,
        input_gen_fns=input_gen_fns,
    )

    # need subgraph inputs and outputs to analyze all symints used in flex attention
    buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list(
        mask_mod_other_buffers
    )
    buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs(
        score_mod_subgraph, mask_mod_subgraph
    )

    # Reduction
    # Merge the per-split partial results along dim 1 (the KV-split axis).

    g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
    # See [Note] Handle fully masked out rows:
    # g_M is the global max among split kv blocks.
    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
    adj_M = lowerings[aten.sub](buf_M, g_M)
    # Rows whose global max is -inf get a zero adjustment, so alpha becomes 1 there.
    adj_M = lowerings[aten.where](masked_rows, 0, adj_M)
    alpha = lowerings[aten.exp2](adj_M)

    buf_L = lowerings[aten.mul](buf_L, alpha)
    g_L = lowerings[aten.sum](buf_L, axis=1)
    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
    # Use 1.0 as the sum for fully masked rows so log2 and the division below stay finite.
    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
    logsumexp = lowerings[aten.log2](g_L)
    logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))

    # Rescale each split's accumulator by alpha, sum over splits, then normalize by g_L.
    alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
    buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
    output = lowerings[aten.sum](buf_ACC, axis=1)
    L_unseq = lowerings[aten.unsqueeze](g_L, 3)
    output = lowerings[aten.div](output, L_unseq)
    output = lowerings[prims.convert_element_type](output, query.get_dtype())

    return (
        output,
        logsumexp,
    )
torch.ops.aten +prims = torch.ops.prims + + +@functools.lru_cache(maxsize=1) +def ensure_flash_available() -> bool: + """Check if flash-attn is importable; cache the result for reuse. + + Call ensure_flash_available.cache_clear() after installing flash-attn + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("flash_attn.cute") is not None # type: ignore[attr-defined] + except ImportError: + return False + + +from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate + + +flash_attention_cutedsl_template = CuteDSLTemplate( + name="flash_attention_cutedsl", source=load_flex_template("flash_attention") +) +flash_attention_backward_cutedsl_template = CuteDSLTemplate( + name="flash_attention_backward_cutedsl", + source=load_flex_template("flash_attention_backward"), +) + + +def _fixed_indexer_cute( + size: Sequence[int], + stride: Optional[Sequence[int]] = None, + offset: Expr = Integer(0), +) -> Callable[[Sequence[Expr]], Expr]: + """ + Colexicographic indexer for CuteDSL - matches CuTe's coordinate interpretation. + + CuTe interprets linear indices in colexicographic (column-major) order, + whereas Inductor's default _fixed_indexer uses lexicographic (row-major) order. + + For size=[4, 128] with index=[b, q_idx]: + - Lexicographic: b*128 + q_idx*1 + - Colexicographic: b*1 + q_idx*2 + + CuTe then applies the tensor's actual memory strides to get the correct offset. + """ + + def indexer(index: Sequence[Expr]) -> Expr: + assert offset == Integer(0), "Offset not supported for colexicographic indexing" + if not index: + return Integer(0) + + result = index[0] + runner = size[0] + + for idx, sz in zip(index[1:], size[1:], strict=True): + result = result + runner * Identity(idx) + runner = runner * sz + + return result + + return indexer + + +@contextmanager +def patch_fixed_layout_indexer_for_cutedsl(): + """ + Temporarily swap FixedLayout.make_indexer so CuteDSL sees colexicographic indexing. 
def wrap_choice_render_with_cutedsl_indexer(choice: Any) -> None:
    """
    Replace ``choice.make_kernel_render`` with a version whose render callable
    runs under the CuteDSL indexer patch.

    See Note [CuteDSL indexer patch]:
    The template still builds its closures normally. Only the returned render
    callable is wrapped, so the colexicographic indexer is active just for the
    render call that emits the kernel and the template's setup logic is untouched.
    """
    unpatched_factory = choice.make_kernel_render

    def patched_factory(*args, **kwargs):
        # Build kernel and render exactly as the template would.
        produced = unpatched_factory(*args, **kwargs)
        kernel, plain_render = produced
        # The context manager doubles as a decorator, scoping the patch to
        # each invocation of the render callable.
        return kernel, patch_fixed_layout_indexer_for_cutedsl()(plain_render)

    choice.make_kernel_render = patched_factory
@functools.lru_cache(maxsize=1)
def _supports_nontrivial_mask_graphs() -> bool:
    """Return True if the current CUDA device has compute capability major 9 or 10.

    Major version 9 is Hopper (SM90) and major version 10 is Blackwell
    (SM100). The result is cached after the first call, so it reflects the
    device that was current at that time.
    """
    return torch.cuda.get_device_capability()[0] in (9, 10)
def create_flex_flash_attention_kernel(
    query: TensorBox,
    key: TensorBox,
    value: TensorBox,
    block_mask: tuple[Any, ...],
    scale: float,
    kernel_options: dict[str, Any],
    subgraph_buffer: SubgraphResults,
    mask_graph_buffer: SubgraphResults,
    score_mod_other_buffers: list[TensorBox],
    mask_mod_other_buffers: list[TensorBox],
    kv_num_blocks: TensorBox | None,
    kv_indices: TensorBox | None,
    full_kv_num_blocks: TensorBox | None,
    full_kv_indices: TensorBox | None,
    mask_graph: Subgraph,
    subgraph: Subgraph | None = None,
) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]:
    """Create a flex flash attention kernel using CuteDSL template.

    Args:
        query, key, value: Attention inputs. ``query``'s size is unpacked as
            ``(batch_size, num_heads, seq_len_q, head_dim)``; the output head
            dim is taken from the last dimension of ``value``.
        scale: Passed to the template as ``SM_SCALE``.
        subgraph_buffer, mask_graph_buffer: Passed to the template as its
            two subgraphs, in that order.
        kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices:
            Block-sparsity inputs. They are appended to the template inputs
            only when ``full_kv_num_blocks`` is not None.
        mask_graph: Its ``graph_module`` decides whether a block mask is needed.
        block_mask, kernel_options, score_mod_other_buffers,
        mask_mod_other_buffers, subgraph: Accepted but not read in this function.

    Returns:
        ``(template_output, lse)``: the output node of the first (and only
        used) template choice, and the fp32 log-sum-exp buffer that the
        template mutates.

    Raises:
        RuntimeError: if flash attention is unavailable, or the template
            reported an error / produced no choice.
        NotImplementedError: if the mask graph is non-trivial but no full
            blocks were provided.
    """
    if not ensure_flash_available():
        raise RuntimeError("CUTE flash attention not available")

    # Get dimensions
    batch_size, num_heads, seq_len_q, head_dim = query.get_size()
    v_head_dim = value.get_size()[-1]
    device = query.get_device()
    dtype = query.get_dtype()
    assert device is not None, "Device must be specified"

    # Match stride pattern from query tensor
    q_strides = query.get_stride()
    out_size = [batch_size, num_heads, seq_len_q, v_head_dim]
    out_strides = infer_dense_strides(out_size, q_strides)

    # In this function, `output` is only used for its strides in output_layout below.
    output = empty_strided(
        size=out_size,
        stride=out_strides,
        dtype=dtype,
        device=device,
    )

    lse = empty_strided(
        size=[batch_size, num_heads, seq_len_q],
        stride=None,  # LSE can be contiguous
        dtype=torch.float32,  # LSE is always fp32
        device=device,
    )

    # Create layout for primary output
    output_layout = FixedLayout(
        device=device,
        dtype=dtype,
        size=[batch_size, num_heads, seq_len_q, v_head_dim],
        stride=[sympy.sympify(s) for s in output.get_stride()],
    )

    # Used to check if we can skip block sparse impl
    mask_graph_is_trivial = is_trivial_mask_graph(mask_graph.graph_module)

    needs_block_mask = not mask_graph_is_trivial
    has_full_blocks = full_kv_num_blocks is not None

    choices: list[Any] = []
    assert flash_attention_cutedsl_template is not None

    # lse is both an input and a mutated input of the template.
    input_nodes = [query, key, value, lse]
    if has_full_blocks:
        input_nodes.extend(
            [kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices]
        )

    if needs_block_mask and not has_full_blocks:
        raise NotImplementedError(
            "Flash attention with block mask but without full blocks is not supported yet"
        )

    error = flash_attention_cutedsl_template.maybe_append_choice(
        choices,
        input_nodes=input_nodes,
        layout=output_layout,
        mutated_inputs=[lse],
        subgraphs=[subgraph_buffer, mask_graph_buffer],
        SM_SCALE=scale,
        NEEDS_BLOCK_MASK=needs_block_mask,
    )

    # Scope the CuteDSL (colexicographic) indexer patch to each choice's render call.
    for choice in choices:
        wrap_choice_render_with_cutedsl_indexer(choice)

    if error or not choices:
        # No fallback is attempted here: the template failure is raised to the caller.
        raise RuntimeError(f"CuteDSL template failed: {error}")

    # No autotune for now
    template_output = choices[0].output_node()

    return (template_output, lse)
not is_trivial_mask_graph(mask_graph.graph_module): + return False, "NYI: Flex Flash Attention doesn't support block_sparsity yet." + + return True, "" + + +def _use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], +) -> bool: + """Determine if we should use flex flash attention for the given inputs. + + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False + + can_use, reason = _can_use_flex_flash_attention_backward( + fw_subgraph, + mask_graph, + ) + + if not can_use: + raise RuntimeError( + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" + ) + + return True + + +def create_flex_flash_attention_backward_kernel( + query: TensorBox, + key: TensorBox, + value: TensorBox, + out: TensorBox, + logsumexp: TensorBox, + grad_out: TensorBox, + scale: float, + kernel_options: dict[str, Any], + # TODO: will be needed + # grad_logsumexp, + # fw_graph: SubgraphResults, + # joint_graph: SubgraphResults, + # mask_graph: SubgraphResults, + # score_mod_other_buffers: list[TensorBox], + # mask_mod_other_buffers: list[TensorBox], + # kv_num_blocks: TensorBox | None, + # kv_indices: TensorBox | None, + # full_kv_num_blocks: TensorBox | None, + # full_kv_indices: TensorBox | None, +) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox, TensorBox, tuple]: + """Create a CuteDSL flash attention backward kernel for the default mod path.""" + if not ensure_flash_available(): + raise RuntimeError("CUTE flash attention not available") + + batch_size, num_heads, seq_len_q, head_dim = 
query.get_size() + v_head_dim = value.get_size()[-1] + device = query.get_device() + dtype = query.get_dtype() + assert device is not None + + grad_query_strides = infer_dense_strides( + [batch_size, num_heads, seq_len_q, head_dim], query.get_stride() + ) + grad_query = empty_strided( + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=grad_query_strides, + dtype=dtype, + device=device, + ) + + grad_key_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], head_dim], key.get_stride() + ) + grad_key = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], head_dim], + stride=grad_key_strides, + dtype=dtype, + device=device, + ) + + grad_value_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], v_head_dim], value.get_stride() + ) + grad_value = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], v_head_dim], + stride=grad_value_strides, + dtype=dtype, + device=device, + ) + + output_layout = FixedLayout( + device=device, + dtype=dtype, + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=[sympy.sympify(s) for s in grad_query.get_stride()], + ) + + choices: list[Any] = [] + + input_nodes = [ + query, + key, + value, + out, + grad_out, + logsumexp, + grad_key, + grad_value, + ] + + error = flash_attention_backward_cutedsl_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=output_layout, + mutated_inputs=[grad_key, grad_value], + SM_SCALE=scale, + ) + + for choice in choices: + wrap_choice_render_with_cutedsl_indexer(choice) + + if error or not choices: + raise RuntimeError(f"CuteDSL template failed: {error}") + + template_output = choices[0].output_node() + + return (template_output, grad_key, grad_value, tuple()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/common.py.jinja 
# forward_block_mn: process one KV tile of the flex attention forward pass.
# It loads a K tile, computes qk = q @ k^T, applies score_mod (subgraph 0) and,
# unless IS_FULL_BLOCKS, mask_mod (subgraph 1), then updates the running
# softmax state (acc, l_i, m_i) with a V tile loaded at the same kv offset.
# Returns the updated (acc, l_i, m_i).
@triton.jit
def forward_block_mn(
    {{gen_argdefs()}},
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    {{gen_defines() | indent_except_first(1)}}

    # -- load k --
    # NB: reversed order since K is transposed
    # Row offset of this tile; reused for the V load further down.
    kv_base_offset = kv_start + kv_offset
    {%- if USE_TMA %}
    k = tl.load_tensor_descriptor(
        desc_k,
        [kv_base_offset, 0],
    )
    {%- else %}

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
    {%- endif %}

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
    # With PRESCALE_QK the scaling is not applied here (the caller is expected to have scaled q).
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid out-of-bounds memory access,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    {{ modification(
        subgraph_number=0,
        output_name="post_mod_scores",
        score="qk",
        b="off_z",
        h="off_h",
        m="m",
        n="n",
        out="qk"
    ) | indent_except_first(1) }}

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    # Full blocks skip mask_mod entirely; only partially masked blocks evaluate it.
    if not IS_FULL_BLOCKS:
        {{ modification(
            subgraph_number=1,
            output_name="mask_mod_output",
            score="qk",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
        ) | indent_except_first(2) }}

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    # Scores are multiplied by RCP_LN2 so that exp2 can be used below.
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        # Rows that are entirely -inf use 0 as the subtracted max, so the
        # exp2 arguments below do not become (-inf) - (-inf).
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    {%- if USE_TMA %}
    v = tl.load_tensor_descriptor(
        desc_v,
        [kv_base_offset, 0],
    )
    {%- else %}
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    {%- endif %}
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i
+ acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..2831ba6af5b60ef469d122a4886dbed9b557ede3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja @@ -0,0 +1,28 @@ +{{def_kernel("Q", "K", "V", "OUT", "D_OUT", "LSE", "DK", "DV")}} + from flash_attn.cute.interface import _flash_attn_bwd + + q_transposed = Q.transpose(1, 2) + k_transposed = K.transpose(1, 2) + v_transposed = V.transpose(1, 2) + out_transposed = OUT.transpose(1, 2) + d_out_transposed = D_OUT.transpose(1, 2) + + dq_transposed, dk_transposed, dv_transposed = _flash_attn_bwd( + q_transposed, + k_transposed, + v_transposed, + out_transposed, + d_out_transposed, + LSE, + softmax_scale={{SM_SCALE}}, + ) + + dq = dq_transposed.transpose(1, 2) + dk = dk_transposed.transpose(1, 2) + dv = dv_transposed.transpose(1, 2) + + dq_out = {{get_output()}} + {# TODO: add out support to flash #} + dq_out.copy_(dq) + DK.copy_(dk) + DV.copy_(dv) diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..b92ea6c14a33fe11bb0c9bd485ca15be60317ded --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja @@ -0,0 +1,215 @@ +{{def_kernel("Q", "K", "V", "LSE", "MAX", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. 
perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN, QK_HEAD_DIM], + strides=[stride_qm, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask", val_shape=("BLOCK_M", "V_HEAD_DIM_ROUNDED"))}} + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/utilities.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/utilities.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..0c40b43277f8ae2da748487803758ff46c338ced --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/utilities.py.jinja @@ -0,0 +1,59 @@ + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: 
tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..989f297c5f80f4053cbc54f6299181d4722efdb2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. 
+# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. 
+ """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. + """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * 
stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = 
input_b.shape + + dev = input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def 
compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..34ff2d69793c004b050cfbbd939218a7ed6a255f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja @@ -0,0 +1,107 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + + # Note: We require TMA_EXPERIMENTAL_API == False, which + # we will check before invoking this template. 
+ stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_M * grid_n + + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE + ): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_K + a = tl.load_tensor_descriptor( + a_desc, + [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], + ) + b = tl.load_tensor_descriptor( + b_desc, + [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], + ) + accumulator += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_cm = pid_m * BLOCK_M + offs_cn = pid_n * BLOCK_N + {%- if EPILOGUE_SUBTILE %} + tl.static_assert(BLOCK_N % 2 == 0) + acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + {{store_output( + ("offs_cm", "offs_cn"), + "acc0", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} 
+ offs_cn2 = offs_cn + BLOCK_N // 2 + {{store_output( + ("offs_cm", "offs_cn2"), + "acc1", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + {%- else %} + {{store_output( + ("offs_cm", "offs_cn"), + "accumulator", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True + )}} + {%- endif %} + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + GROUP_M = min(grid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % GROUP_M) + pid_n = (tile_id % num_pid_in_group) // GROUP_M + return pid_m, pid_n diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..56ef18b7a91e3cea8fb49da3465082cc47162a09 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja @@ -0,0 +1,194 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALE_RECIPE_A == 1: # ScalingType.RowWise + stride_a_scale_m = 1 + else: + stride_a_scale_m = 0 + + if SCALE_RECIPE_B == 1: # ScalingType.RowWise + stride_b_scale_n = 1 + else: + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + {%- if 
TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = 
tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A, + SCALE_RECIPE_B, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + {%- endif %} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A: tl.constexpr, + SCALE_RECIPE_B: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1: # (ScalingType.RowWise, ScalingType.RowWise) + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + 
a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: # (ScalingType.TensorWise, ScalingType.TensorWise) + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..171340a2c92333c3e514f560183ac746c458b9ce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja @@ -0,0 +1,212 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, 
BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + + am_blocks = tl.cdiv(M, TILE_SIZE_A) + ak_blocks = tl.cdiv(K, TILE_SIZE_A) + bn_blocks = tl.cdiv(N, TILE_SIZE_B) + bk_blocks = tl.cdiv(K, TILE_SIZE_B) + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_a_block = blockwise128x128_scaling( + pid_m, + a_scale, + ki, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AM, + MIN_BLOCK_TILE_AK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_a_block = blockwise1xTILESIZE_scaling( + pid_m, + a_scale, + ki, + M, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AK, + TILE_SIZE_A, + ) + {%- endif %} + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_b_block = blockwise128x128_scaling( + pid_n, + b_scale, + ki, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BN, + MIN_BLOCK_TILE_BK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_b_block = blockwise1xTILESIZE_scaling( + pid_n, + b_scale, + ki, + N, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BK, + TILE_SIZE_B, + ) + {%- endif %} + + a_scaled = a * scale_a_block + b_scaled = b * scale_b_block + accumulator = tl.dot(a_scaled, b_scaled.T, accumulator) + + if ki == k_tiles - 1: + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + 
tl.arange(0, BLOCK_N) + + # inductor generates a suffix + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def blockwise1xTILESIZE_scaling( + pid, + scale, + ki, + lhs_size, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, + TILE_SIZE: tl.constexpr, +): + row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs) + col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + BLOCK_lhs, + ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K + ) + + return scale_expanded + + +@triton.jit +def blockwise128x128_scaling( + pid, + scale, + ki, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_lhs: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, +): + row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128) + col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_blocks) & 
(col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs, + ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K + ) + + return scale_expanded diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..2da348f3e767cfbb91350ccb3831c9bf07b07528 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm.py.jinja @@ -0,0 +1,72 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 
+ else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..42b99c70d5cbd5394c00662793b212661c48e48b --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja @@ -0,0 +1,71 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", 
"idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..38fe092c257803f4676092af83e40e3eeb55f8c7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja @@ -0,0 +1,129 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + 
desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + 
b_desc_ptr, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output(("rm", "rn"), "acc", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"), block_indexing=True)}} + {%- endif %} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..becac750003df0240b2708840bbc9fa19599ff2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py @@ -0,0 +1,2372 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import argparse +import functools +from typing import List, Type, Union +from inspect import isclass + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.torch as cutlass_torch + +""" +A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of grouped GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. +The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices +in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and +strides are also stored in arrays in GMEM. + +This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM + +The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape +is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data type +are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. 
code-block:: bash + + ncu python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + +There are some constraints for this example. Besides the constraints from the Blackwell dense GEMM persistent example, +there are also the following constraints: +* Only fp16 and bf16 data types are supported as inputs. +* Output data types could be fp16, bf16 or fp32. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The l mode (a.k.a. batch size) for each group must be 1. +* The majorness for A, B and C must be the same across all groups. +""" + + +class GroupedGemmKernel: + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + tensormap_update_mode: utils.TensorMapUpdateMode = utils.TensorMapUpdateMode.SMEM, + ): + """Initializes the configuration for a Blackwell grouped GEMM kernel. + + Besides configurations for dense persistent GEMM, there is an extra config specific to grouped GEMM: + + Tensormap Update Mode: + - tensormap_update_mode: Specifies whether the tensormap is + updated in global memory(GMEM) or shared memory(SMEM). + The 2 modes are functionally equivalent and the differences are: + - We buffer 3 tensormaps in SMEM for A, B, and C tensors (each TMA descriptor takes 128B) when TMA updates are performed in SMEM. + - Performance varies between modes depending on problem size; optimal choice differs across workloads. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. 
+ :type use_2cta_instrs: bool + :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: tuple[int, int] + :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: tuple[int, int] + :param tensormap_update_mode: Mode for updating the tensormap (GMEM or SMEM), defaults to SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode, optional + """ + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.tensormap_update_mode = tensormap_update_mode + # Delegate tensormap ab initialization to MMA warp when SMEM mode is used for better latency hiding + self.delegate_tensormap_ab_init = ( + tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ) + + self.num_mcast_ctas_a = 1 + self.num_mcast_ctas_b = 1 + self.is_a_mcast = False + self.is_b_mcast = False + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier for epilog sync, tmem ptr sync and tensormap update sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + # Barrier used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * (len(self.epilog_warp_id) + 1), + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.num_tma_load_bytes = 0 + + def 
_setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + Most of the implementation follows standard dense GEMM patterns, + with the key difference being additional consideration for SMEM + buffer needed for tensormap updates. + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_tile_shape_mnk = tuple( + x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_epi_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + 
self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_epi_stage, + ) + + mbar_smem_bytes = self._get_mbar_smem_bytes( + num_acc_stage=self.num_acc_stage, + num_ab_stage=self.num_ab_stage, + num_epi_stage=self.num_epi_stage, + ) + tensormap_smem_bytes = self._get_tensormap_smem_bytes( + self.tensormap_update_mode + ) + if ( + mbar_smem_bytes + + tensormap_smem_bytes + + GroupedGemmKernel.tensor_memory_management_bytes + > self.reserved_smem_bytes + ): + raise ValueError( + f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " + f"reserved smem bytes {self.reserved_smem_bytes}" + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + group_count: cutlass.Constexpr[int], + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + total_num_clusters: cutlass.Constexpr[int], + tensormap_cute_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr[int], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided + by different tensors in global memory. The "initial" tensors only carry data type and + majorness information. 
+ + :param initial_a: Initial tensor A, used for data type and majorness information. + :type initial_a: cute.Tensor + :param initial_b: Initial tensor B, used for data type and majorness information. + :type initial_b: cute.Tensor + :param initial_c: Initial tensor C, used for data type and majorness information. + :type initial_c: cute.Tensor + :param group_count: The number of GEMM groups. + :type group_count: cutlass.Constexpr[int] + :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. + :type problem_shape_mnkl: cute.Tensor + :param strides_abc: Tensor containing the strides for A, B, and C for each group. + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. + :type tensor_address_abc: cute.Tensor + :param total_num_clusters: Total number of clusters needed for all groups. + :type total_num_clusters: cutlass.Constexpr[int] + :param tensormap_cute_tensor: Tensor for storing tensormaps. + :type tensormap_cute_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + :param stream: CUDA stream for asynchronous execution. + :type stream: cuda.CUstream + :raises TypeError: If A and B data types do not match. 
+ """ + self.a_dtype = initial_a.element_type + self.b_dtype = initial_b.element_type + self.c_dtype = initial_c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(initial_c) + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + initial_a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + initial_b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + initial_c, + epi_smem_layout, + 
self.epi_tile, + ) + + self.tile_sched_params, grid = self._compute_grid( + total_num_clusters, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + self.size_tensormap_in_i64 = ( + 0 + if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM + else GroupedGemmKernel.num_tensormaps + * GroupedGemmKernel.bytes_per_tensormap + // 8 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, self.size_tensormap_in_i64 + ] + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.epi_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + tensormap_cute_tensor, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + 
cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + group_count: cutlass.Constexpr[int], + problem_sizes_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + ptrs_abc: cute.Tensor, + tensormaps: cute.Tensor, + ): + """ + GPU device kernel performing the grouped GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coord inside cluster + bid = cute.arch.block_idx() + mma_tile_coord_v = bid[0] % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tensormap_a_smem_ptr = None + tensormap_b_smem_ptr = None + tensormap_c_smem_ptr = None + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + 
tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_c_smem_ptr = ( + tensormap_b_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + ab_full_mbar_ptr = storage.ab_full_mbar_ptr.data_ptr() + ab_empty_mbar_ptr = storage.ab_empty_mbar_ptr.data_ptr() + acc_full_mbar_ptr = storage.acc_full_mbar_ptr.data_ptr() + acc_empty_mbar_ptr = storage.acc_empty_mbar_ptr.data_ptr() + + # init barrier for loading A, B with TMA + if warp_idx == self.epilog_warp_id[0]: + for k_stage in range(self.num_ab_stage): + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1) + cute.arch.mbarrier_init( + ab_empty_mbar_ptr + k_stage, num_tma_producer + ) + # Accumulator barrier init + if warp_idx == self.mma_warp_id: + for acc_stage in range(self.num_acc_stage): + with cute.arch.elect_one(): + cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1) + cute.arch.mbarrier_init( + acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4 + ) + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, 
swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full and empty + # + a_full_mcast_mask = None + b_full_mcast_mask = None + ab_empty_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask + acc_full_mcast_mask = None + if cutlass.const_expr(use_2cta_instrs): + acc_full_mcast_mask = cute.make_layout_image_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + block_in_cluster_coord_vmnk[0] ^ 1, + *block_in_cluster_coord_vmnk[1:], + ) + a_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + b_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + ab_empty_mcast_mask = ( + a_full_mcast_mask_peer + | b_full_mcast_mask_peer + | cutlass.Int16( + 0 if ab_empty_mcast_mask is None else ab_empty_mcast_mask + ) + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, 
MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for load A, B with TMA + # + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Get tensormap buffer address + # + grid_dim = cute.arch.grid_dim() + tensormap_workspace_idx = ( + bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0] + ) + + tensormap_manager = utils.TensorMapManager( + self.tensormap_update_mode, GroupedGemmKernel.bytes_per_tensormap + ) + tensormap_a_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 0, None)].iterator + ) + tensormap_b_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 1, None)].iterator + 
) + tensormap_c_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 2, None)].iterator + ) + # Setup tensormap initialization pointer based on the mode + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_a_init_ptr = tensormap_a_smem_ptr + tensormap_b_init_ptr = tensormap_b_smem_ptr + tensormap_c_init_ptr = tensormap_c_smem_ptr + else: + tensormap_a_init_ptr = tensormap_a_ptr + tensormap_b_init_ptr = tensormap_b_ptr + tensormap_c_init_ptr = tensormap_c_ptr + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # Initialize tensormaps for A, B + if cutlass.const_expr(self.delegate_tensormap_ab_init == False): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.tma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.tma_warp_id + ) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + tensormap_init_done = cutlass.Boolean(False) + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + # skip tensormap update if we're working on the same group + if 
is_group_changed: + real_tensor_a = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.a_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 0, # 0 for tensor A + ) + real_tensor_b = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.b_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 1, # 1 for tensor B + ) + # wait tensormap initialization complete before update + if tensormap_init_done == False: + if cutlass.const_expr(self.delegate_tensormap_ab_init): + self.tensormap_ab_init_barrier.arrive_and_wait() + tensormap_manager.fence_tensormap_initialization() + tensormap_init_done = True + + tensormap_manager.update_tensormap( + (real_tensor_a, real_tensor_b), + (tma_atom_a, tma_atom_b), + (tensormap_a_ptr, tensormap_b_ptr), + self.tma_warp_id, + (tensormap_a_smem_ptr, tensormap_b_smem_ptr), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + tma_wr_k_tile = cutlass.Int32(0) + smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage + tma_wr_ab_empty_phase = ( + num_prev_k_blk + tma_wr_k_tile + ) // self.num_ab_stage % 2 ^ 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile < cur_k_tile_cnt, + ab_empty_mbar_ptr 
+ smem_wr_buffer, + tma_wr_ab_empty_phase, + ) + # ensure the update to tensormap has completed before using it + if is_group_changed: + tensormap_manager.fence_tensormap_update(tensormap_a_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_ptr) + # + # Tma load loop + # + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + tma_wr_k_tile_next = tma_wr_k_tile + 1 + smem_wr_buffer_next = ( + num_prev_k_blk + tma_wr_k_tile_next + ) % self.num_ab_stage + tma_wr_ab_empty_phase_next = ( + tma_wr_ab_empty_phase ^ 1 + if smem_wr_buffer_next == 0 + else tma_wr_ab_empty_phase + ) + + smem_full_mbar_ptr = ab_full_mbar_ptr + smem_wr_buffer + + # Wait for AB buffer empty + if peek_ab_empty_status == 0: + cute.arch.mbarrier_wait( + ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase + ) + + # Arrive AB buffer and expect full transaction bytes + if is_leader_cta: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + smem_full_mbar_ptr, self.num_tma_load_bytes + ) + + # Load A/B with TMA + cute.copy( + tma_atom_a, + tAgA_slice[(None, tma_wr_k_tile)], + tAsA[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=a_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, tma_wr_k_tile)], + tBsB[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=b_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_ptr, + cute.AddressSpace.generic, + ), + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile_next < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer_next, + tma_wr_ab_empty_phase_next, + ) + + tma_wr_k_tile = tma_wr_k_tile_next + smem_wr_buffer = smem_wr_buffer_next + tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next + + # Advance to next tile 
+ tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # Bar sync for retrieve tmem ptr from shared mem + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + # MMA warp is only interested in number of tiles along K dimension + ( + cur_k_tile_cnt, + cur_group_idx, + ) = group_gemm_ts_helper.search_cluster_tile_count_k( + cur_tile_coord, + problem_sizes_mnkl, + ) + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer full for k_tile = 0 + mma_rd_k_tile = cutlass.Int32(0) + smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage + if is_leader_cta: + need_check_rd_buffer_full = ( + mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta + ) + mma_rd_ab_full_phase = ( + (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 + ) + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + 
need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer, + mma_rd_ab_full_phase, + ) + + # + # Wait for accumulator buffer empty + # + acc_empty_phase = ( + tile_sched.num_tiles_executed // self.num_acc_stage % 2 ^ 1 + ) + cute.arch.mbarrier_wait( + acc_empty_mbar_ptr + acc_buf_idx, acc_empty_phase + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) + smem_rd_buffer_next = ( + num_prev_k_blk + mma_rd_k_tile_next + ) % self.num_ab_stage + mma_rd_ab_full_phase_next = ( + mma_rd_ab_full_phase ^ 1 + if smem_rd_buffer_next == 0 + else mma_rd_ab_full_phase + ) + # Wait for AB buffer full + if peek_ab_full_status == 0: + cute.arch.mbarrier_wait( + ab_full_mbar_ptr + smem_rd_buffer, mma_rd_ab_full_phase + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, smem_rd_buffer) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + with cute.arch.elect_one(): + tcgen05.commit( + ab_empty_mbar_ptr + smem_rd_buffer, + ab_empty_mcast_mask, + self.cta_group, + ) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + need_check_rd_buffer_full = ( + mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta + ) + + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer_next, + mma_rd_ab_full_phase_next, + ) + + mma_rd_k_tile = mma_rd_k_tile_next + smem_rd_buffer = smem_rd_buffer_next + mma_rd_ab_full_phase = mma_rd_ab_full_phase_next + + # + # Async arrive accumulator buffer full + # + with cute.arch.elect_one(): + tcgen05.commit( + 
acc_full_mbar_ptr + acc_buf_idx, + acc_full_mcast_mask, + self.cta_group, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # initialize tensormap A, B for TMA warp + if cutlass.const_expr(self.delegate_tensormap_ab_init): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.epilog_warp_id[0] + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.epilog_warp_id[0] + ) + # signal tensormap initialization has finished + self.tensormap_ab_init_barrier.arrive_and_wait() + # initialize tensorap for C + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_init_ptr, + self.epilog_warp_id[0], + ) + # Alloc tensor memory buffer + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + epi_tidx = tidx + # + # Partition for epilogue + # + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = 
utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # wait tensormap initialization complete before update + tensormap_manager.fence_tensormap_initialization() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + if is_group_changed: + # construct tensor C based on real address, shape and stride information + real_tensor_c = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.c_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 2, # 2 for tensor C + ) + tensormap_manager.update_tensormap( + ((real_tensor_c),), + ((tma_atom_c),), + ((tensormap_c_ptr),), + self.epilog_warp_id[0], + (tensormap_c_smem_ptr,), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + total_k_tile_cnt += cur_k_tile_cnt + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_buf_idx)] + + # + # Wait for accumulator buffer full + # + 
acc_full_phase = tile_sched.num_tiles_executed // self.num_acc_stage % 2 + cute.arch.mbarrier_wait(acc_full_mbar_ptr + acc_buf_idx, acc_full_phase) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + # ensure the update to tensormap has completed before using it + if is_group_changed: + if warp_idx == self.epilog_warp_id[0]: + tensormap_manager.fence_tensormap_update(tensormap_c_ptr) + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to output type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + # + # Store C to shared memory + # + epi_buffer = (num_prev_subtiles + subtile_idx) % self.num_epi_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, epi_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # store C to global memory with TMA + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, epi_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_c_ptr, + cute.AddressSpace.generic, + ), + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group( + self.num_epi_stage - 1, read=True + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive( + acc_empty_mbar_ptr + 
acc_buf_idx, + cta_rank_in_cluster // 2 * 2 if use_2cta_instrs else None, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + + # + # Wait a/b buffer empty + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.mbarrier_wait( + (ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)), + (((total_k_tile_cnt - 1) // self.num_ab_stage) % 2), + ) + + @cute.jit + def make_tensor_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_index: int, + ): + """Extract stride and tensor address for a given group and construct a global tensor. + + This function is used within the kernel to dynamically create a CUTE tensor + representing A, B, or C for the current group being processed, using the + group-specific address, shape, and stride information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). + :type tensor_address_abc: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. 
+ :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] + strides_tensor_reg = cute.make_rmem_tensor( + cute.make_layout(2), + strides_abc.element_type, + ) + cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) + stride_mn = strides_tensor_reg[0] + stride_k = strides_tensor_reg[1] + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(tensor_index == 0): # tensor A + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), + ) + elif cutlass.const_expr(tensor_index == 1): # tensor B + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), + ) + else: # tensor C + m = problem_shape_mnk[0] + n = problem_shape_mnk[1] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). 
+ + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load(t2r) + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + 
""" + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tma_atom_c: cute.CopyAtom, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to partition + shared memory (source) and global memory (destination) for TMA store version. + + :param tma_atom_c: The TMA copy atom configured for storing tensor C. + :type tma_atom_c: cute.CopyAtom + :param gC_mnl: The global memory tensor C. + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler defining the granularity of the operation. + :type epi_tile: cute.Tile + :param sC: The shared memory epilogue buffer tensor. 
+ :type sC: cute.Tensor + :return: A tuple containing: + - tma_atom_c: The input TMA copy atom (passed through). + - bSG_sC: The source shared memory tensor partitioned for the TMA operation. + - tCgC: The destination global memory tensor partitioned for the TMA operation. + :rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int, int]: + """Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. 
+ :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (accumulator stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default accumulator and epilogue stages + num_acc_stage = 2 + num_epi_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and Epilogue + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # stage=1 + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # stage=1 + ) + epi_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, # stage=1 + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + epi_bytes_per_stage = cute.size_in_bytes(c_dtype, epi_smem_layout_staged_one) + epi_bytes = epi_bytes_per_stage * num_epi_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial epilogue bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy + - GroupedGemmKernel.reserved_smem_bytes + - epi_bytes + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + remaining_smem = ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes) + ) + num_epi_stage += remaining_smem // (occupancy * epi_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_epi_stage + + @staticmethod + def _compute_grid( + total_num_clusters: int, + 
cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr[int], + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """Compute tile scheduler parameters and grid shape for grouped GEMM operations. + + :param total_num_clusters: Total number of clusters to process across all groups. + :type total_num_clusters: int + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] + """ + # Create problem shape with M, N dimensions from cluster shape + # and L dimension representing the total number of clusters. + problem_shape_ntile_mnl = ( + cluster_shape_mn[0], + cluster_shape_mn[1], + cutlass.Int32(total_num_clusters), + ) + + tile_sched_params = utils.PersistentTileSchedulerParams( + problem_shape_ntile_mnl, (*cluster_shape_mn, 1) + ) + + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: + """Calculate shared memory consumption for memory barriers based on provided stages. + + Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. + The total consumption is the sum across all provided stages. This function calculates the total + shared memory needed for these barriers. + + :param kwargs_stages: Variable keyword arguments where each key is a stage name + (e.g., num_acc_stage, num_ab_stage) and each value is the + number of stages of that type. + :type kwargs_stages: int + :return: Total shared memory bytes required for all memory barriers. 
+ :rtype: int + """ + num_barriers_per_stage = 2 + num_bytes_per_barrier = 8 + mbar_smem_consumption = sum( + [ + num_barriers_per_stage * num_bytes_per_barrier * stage + for stage in kwargs_stages.values() + ] + ) + return mbar_smem_consumption + + @staticmethod + def _get_tensormap_smem_bytes( + tensormap_update_mode: utils.TensorMapUpdateMode, + ) -> int: + """Get the SMEM consumption for the tensormap buffer based on the update mode. + + :param tensormap_update_mode: Specifies whether tensormaps are updated in GMEM or SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode + :return: The shared memory bytes required for the tensormap buffer. Returns 0 if mode is GMEM. + :rtype: int + :raises ValueError: If an invalid tensormap update mode is provided. + """ + if tensormap_update_mode == utils.TensorMapUpdateMode.GMEM: + return 0 + elif tensormap_update_mode == utils.TensorMapUpdateMode.SMEM: + return ( + GroupedGemmKernel.bytes_per_tensormap * GroupedGemmKernel.num_tensormaps + ) + else: + raise ValueError(f"Invalid tensormap update mode: {tensormap_update_mode}") + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param acc_stage: The stage of the accumulator tensor. + :type acc_stage: int + + :return: The number of tensor memory allocation columns. 
+ :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + # Size of smem we reserved for mbarrier, tensor memory management and tensormap update + reserved_smem_bytes = 1024 + bytes_per_tensormap = 128 + num_tensormaps = 3 + # size of smem used for tensor memory management + tensor_memory_management_bytes = 12 + + +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, + torch_tensor_cpu: torch.Tensor = None, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create GPU tensor from either a new or existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. 
+ :type torch_tensor_cpu: torch.Tensor, optional + """ + if torch_tensor_cpu is None: + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + torch_tensor.stride()[:-1], + ) + + +def create_tensors_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + torch_fp32_tensors_abc: List[List[torch.Tensor]] = None, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len( + problem_sizes_mnkl + ): + raise ValueError("torch_fp32_tensors_abc must have one entry per group") + + # Initialize lists to store tensors for all groups + new_torch_fp32_tensors_abc = ( + [] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc + ) + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Get existing CPU tensors if available, otherwise None + existing_cpu_a = ( + torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None + ) + existing_cpu_b = ( + torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None + ) + existing_cpu_c = ( + torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None + ) + + # Create tensors (reusing CPU tensors if provided) + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + tensor_fp32_a, + stride_mk_a, + ) = create_tensor_and_stride( + l, m, k, 
a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a + ) + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + tensor_fp32_b, + stride_nk_b, + ) = create_tensor_and_stride( + l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b + ) + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + tensor_fp32_c, + stride_mn_c, + ) = create_tensor_and_stride( + l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c + ) + + # Only append to new_torch_fp32_tensors_abc if we created new CPU tensors + if torch_fp32_tensors_abc is None: + new_torch_fp32_tensors_abc.append( + [tensor_fp32_a, tensor_fp32_b, tensor_fp32_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + new_torch_fp32_tensors_abc, + ) + + +def run( + num_groups: int, + problem_sizes_mnkl: tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + tensormap_update_mode: utils.TensorMapUpdateMode, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Run grouped GEMM example with specified configurations. 
+ + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print("Running Blackwell Grouped GEMM test with:") + print(f"{num_groups} groups") + for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): + print(f"Group {i}: {m}x{n}x{k}x{l}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Tensor map update mode: {tensormap_update_mode}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Skip unsupported types + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + }: + raise ValueError(f"Skip unsupported ab_dtype {ab_dtype}") + if c_dtype not in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32}: + raise ValueError(f"Skip unsupported c_dtype {c_dtype}") + # Skip unsupported acc dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + raise ValueError(f"Skip unsupported acc_dtype {acc_dtype}") + # Skip invalid ab_dtype and acc_dtype combination + if ab_dtype == cutlass.BFloat16 and acc_dtype == cutlass.Float16: + raise ValueError("Skip invalid ab_dtype and acc_dtype combination") + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + raise ValueError(f"Skip invalid mma tiler M {mma_tiler_mn[0]}") + if mma_tiler_mn[1] not in range(32, 257, 32): + raise ValueError(f"Skip invalid mma tiler N {mma_tiler_mn[1]}") + # Skip 
illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape_m need align with use_2cta_instrs config {cluster_shape_mn}" + ) + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + raise ValueError(f"Skip invalid cluster shape {cluster_shape_mn}") + + # Skip illegal problem shape for load/store alignment + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + raise ValueError("Skip invalid problem alignment") + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Create tensors for all groups using the new function + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + torch_fp32_tensors_abc, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + + # Choose A, B, C with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b) + 
min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c) + initial_cute_tensors_abc = [ + cute_tensors_abc[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc[min_c_idx][2], # C with smallest (m, n) + ] + + hardware_info = utils.HardwareInfo() + sm_count = hardware_info.get_max_active_clusters(1) + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + # Prepare tensormap buffer for each SM + num_tensormap_buffers = sm_count + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + grouped_gemm = GroupedGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + tensormap_update_mode, + ) + + # layout (num_groups, 4):(4, 1) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups, 3, 2):(6, 2, 1) + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,3):(3, 1) + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + def compute_total_num_clusters( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + cluster_tile_shape_mn: tuple[int, int], + ) -> int: + 
total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + ) -> tuple[int, int]: + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + mma_tiler_mn, cluster_shape_mn, use_2cta_instrs + ) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile grouped GEMM kernel + compiled_grouped_gemm = cute.compile( + grouped_gemm, + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + num_groups, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + total_num_clusters, + tensor_of_tensormap, + max_active_clusters, + current_stream, + ) + + if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, + current_stream, + ) + + # Compute reference result + for i, (a, b, c) in enumerate(torch_tensors_abc): + ref = torch.einsum( + "mkl,nkl->mnl", + a.cpu().to(dtype=torch.float32), + b.cpu().to(dtype=torch.float32), + ) + print(f"checking group {i}") + torch.testing.assert_close( + c.cpu(), + ref.to(cutlass_torch.dtype(c_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU tensors and create new GPU tensors from them + ( + 
ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + torch_fp32_tensors_abc, + ) + + initial_cute_tensors_abc_workspace = [ + cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n) + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + return testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensormap_workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in torch_tensors_abc + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + 
workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: + if s.strip().startswith("("): + # Split on ),( to separate tuples + tuples = s.strip("()").split("),(") + result = [] + tuple_len = None + + for t in tuples: + # Parse individual tuple + nums = [int(x.strip()) for x in t.split(",")] + + # Validate tuple length consistency + if tuple_len is None: + tuple_len = len(nums) + elif len(nums) != tuple_len: + raise argparse.ArgumentTypeError( + "All tuples must have the same length" + ) + + result.append(tuple(nums)) + return result + + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers or list of tuples" + ) + + parser = argparse.ArgumentParser( + description="Example of Grouped GEMM on Blackwell." 
+ ) + parser.add_argument( + "--num_groups", + type=int, + default=2, + help="Number of groups", + ) + parser.add_argument( + "--problem_sizes_mnkl", + type=parse_comma_separated_tuples, + default=((128, 128, 128, 1), (128, 128, 128, 1)), + help="a tuple of problem sizes for each group (comma-separated tuples)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--tensormap_update_mode", + type=str, + default="SMEM", + help="Tensor map update mode", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if ( + len(args.problem_sizes_mnkl) != 0 + and 
len(args.problem_sizes_mnkl) != args.num_groups + ): + parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") + + # l mode must be 1 for all groups + for _, _, _, l in args.problem_sizes_mnkl: + if l != 1: + parser.error("l must be 1 for all groups") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + if args.tensormap_update_mode not in ["GMEM", "SMEM"]: + parser.error("--tensormap_update_mode must be GMEM or SMEM") + + if args.tensormap_update_mode == "GMEM": + tensormap_update_mode = utils.TensorMapUpdateMode.GMEM + else: + tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + + torch.manual_seed(2025) + + run( + args.num_groups, + args.problem_sizes_mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + tensormap_update_mode, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c6d0ae641fb5167fb5e9ed64cc150811dbf3aaf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/choices.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/choices.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1cec4688dc3e5e2884862e7572cba21dc6c958f7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/choices.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f55d31d141b2d8c8c6a1819760833bbb36d11a26 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb419edec60d3e1bf782910cf3c3c42ba5b565f5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/package.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1669be8916cd38e7e16a34f720f875b1108dbcf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/package.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c2d6081e6c2605b99a0c3d9b1e582a05116267c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bdd9a8010418fc3186a48de3721da401aa186e5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2db814311940103eba3308a3dadabade8c82ced5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8afbb22c14d104dbaebe73933780f4b6a925efa8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc 
differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb610a5cc4d74ad8eb7deca000cdb88746614ad3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..627991807c65fb2583359f71d298da6dd483799e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/debug_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6ff50fdaddbe829380980766f845af9e6b71ac5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/debug_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f38646ef4691d8378f713876a91ba8fcfb2a52d9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a96aa7237d1b4ab246af161c4014411790d0bcd7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a2a930bd6297e16fbfe31cce4b68153464a84fd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51c9febc403c5ffd1d1e44584ae749b3518c98a3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e625ea7ae6bfd3b0a09275db81c26eecc008307 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c40030120dc7b8e722062005beffa4e034a9d88c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1d364eaf51e009b557e422fe0b5093fe9cfb17 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__init__.py @@ -0,0 +1,68 @@ +from threading import Lock + +from . 
import config, interfaces as intfs, locks +from .context import IsolationSchema, SelectedCompileContext, SelectedRuntimeContext +from .exceptions import ( + CacheError, + CustomParamsEncoderRequiredError, + CustomResultDecoderRequiredError, + CustomResultEncoderRequiredError, + DeterministicCachingDisabledError, + DeterministicCachingIMCDumpConflictError, + DeterministicCachingInvalidConfigurationError, + DeterministicCachingRequiresStrongConsistencyError, + FileLockTimeoutError, + KeyEncodingError, + KeyPicklingError, + LockTimeoutError, + StrictDeterministicCachingKeyNotFoundError, + SystemError, + UserError, + ValueDecodingError, + ValueEncodingError, + ValuePicklingError, + ValueUnPicklingError, +) + + +# fast cache; does not bother supporting deterministic caching, and is essentially +# a memoized on-disk cache. use when deterministic caching is not required +fcache: intfs._CacheIntf = intfs._FastCacheIntf() +# deterministic cache; slower than fcache but provides deterministic guarantees. +# use when deterministic caching is absolutely required, as this will raise +# an exception if use is attempted when deterministic caching is disabled +dcache: intfs._CacheIntf = intfs._DeterministicCacheIntf() +# inductor cache; defaults to the deterministic cache if deterministic caching +# is enabled, otherwise uses the fast cache. 
use when you would like deterministic +# caching but are okay with non-deterministic caching if deterministic caching is disabled +icache: intfs._CacheIntf = ( + dcache if config.IS_DETERMINISTIC_CACHING_ENABLED() else fcache +) + +__all__ = [ + "SelectedCompileContext", + "SelectedRuntimeContext", + "IsolationSchema", + "CacheError", + "SystemError", + "UserError", + "LockTimeoutError", + "FileLockTimeoutError", + "KeyEncodingError", + "KeyPicklingError", + "ValueEncodingError", + "ValuePicklingError", + "ValueDecodingError", + "ValueUnPicklingError", + "CustomParamsEncoderRequiredError", + "CustomResultEncoderRequiredError", + "CustomResultDecoderRequiredError", + "DeterministicCachingDisabledError", + "DeterministicCachingRequiresStrongConsistencyError", + "StrictDeterministicCachingKeyNotFoundError", + "DeterministicCachingInvalidConfigurationError", + "DeterministicCachingIMCDumpConflictError", + "fcache", + "dcache", + "icache", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2b4ab03d4407eed8ba12ec09f5d2fa0ab0e0960 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bda57704ebbba48ffb4882efbd1c4c6f612a6c3 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/context.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a4c24887652183c03c8da37350ed708ccf1d6ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/context.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/exceptions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/exceptions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3962e40ad55f34e540023ad7daeff454fb11808 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/exceptions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/implementations.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/implementations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..460bf412fb93a67ef3adb8ada2ec10d552198336 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/implementations.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/interfaces.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/interfaces.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed7212177a6b805db0cd6030a767bb3083ff5d88 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/interfaces.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/locks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/locks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b6b6e2ae6a0fe6133efe31bc0c2556af185088 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/locks.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b5cf2ba1414c4f82d95cd95a0767396cc20917a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/config.py new file mode 100644 index 0000000000000000000000000000000000000000..14e13f937dbb75ad0b8ca0c197df3e8c2559c098 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/config.py @@ -0,0 +1,127 @@ +import os +from collections.abc import Callable 
+from functools import cache, partial + +import torch +from torch._environment import is_fbcode + + +@cache +def _env_var_config(env_var: str, default: bool) -> bool: + if (env_val := os.environ.get(env_var)) is not None: + return env_val == "1" + return default + + +@cache +def _versioned_config( + jk_name: str, + this_version: int, + oss_default: bool, + env_var_override: str | None = None, +) -> bool: + """ + A versioned configuration utility that determines boolean settings based on: + 1. Environment variable override (highest priority) + 2. JustKnobs version comparison in fbcode environments + 3. OSS default fallback + + This function enables gradual rollouts of features in fbcode by comparing + a local version against a JustKnobs-controlled remote version, while + allowing environment variable overrides for testing and OSS defaults + for non-fbcode environments. + + Args: + jk_name: JustKnobs key name (e.g., "pytorch/inductor:feature_version") + this_version: Local version number to compare against JustKnobs version + oss_default: Default value to use in non-fbcode environments + env_var_override: Optional environment variable name that, when set, + overrides all other logic + + Returns: + bool: Configuration value determined by the priority order above + """ + if ( + env_var_override + and (env_var_value := os.environ.get(env_var_override)) is not None + ): + return env_var_value == "1" + elif is_fbcode(): + # this method returns 0 on failure, which we should check for specifically. + # in the case of JK failure, the safe bet is to simply disable the config + jk_version: int = torch._utils_internal.justknobs_getval_int(jk_name) + return (this_version >= jk_version) and (jk_version != 0) + return oss_default + + +# toggles the entire caching module, but only when calling through the +# public facing interfaces. 
get/insert operations become no-ops in the sense +# that get will always miss and insert will never insert; record becomes a +# no-op in the sense that the function will always be called and the cache +# will never be accessed +_CACHING_MODULE_VERSION: int = 0 +_CACHING_MODULE_VERSION_JK: str = "pytorch/inductor:caching_module_version" +_CACHING_MODULE_OSS_DEFAULT: bool = False +_CACHING_MODULE_ENV_VAR_OVERRIDE: str = "TORCHINDUCTOR_ENABLE_CACHING_MODULE" +IS_CACHING_MODULE_ENABLED: Callable[[], bool] = partial( + _versioned_config, + _CACHING_MODULE_VERSION_JK, + _CACHING_MODULE_VERSION, + _CACHING_MODULE_OSS_DEFAULT, + _CACHING_MODULE_ENV_VAR_OVERRIDE, +) + + +# toggles the deterministic caching interface. silently disabling deterministic +# caching (i.e. by mimicking the functionality of IS_CACHING_MODULE_ENABLED) can +# be problematic if the user is directly calling the deterministic caching interface +# (for example, if they were to interface with dcache instead of icache). instead, if +# the user tries to use the deterministic caching interface while it is disabled we +# will simply throw DeterministicCachingDisabledError +_DETERMINISTIC_CACHING_VERSION: int = 0 +_DETERMINISTIC_CACHING_VERSION_JK: str = ( + "pytorch/inductor:deterministic_caching_version" +) +_DETERMINISTIC_CACHING_OSS_DEFAULT: bool = False +_DETERMINISTIC_CACHING_ENV_VAR_OVERRIDE: str = ( + "TORCHINDUCTOR_ENABLE_DETERMINISTIC_CACHING" +) +IS_DETERMINISTIC_CACHING_ENABLED: Callable[[], bool] = partial( + _versioned_config, + _DETERMINISTIC_CACHING_VERSION_JK, + _DETERMINISTIC_CACHING_VERSION, + _DETERMINISTIC_CACHING_OSS_DEFAULT, + _DETERMINISTIC_CACHING_ENV_VAR_OVERRIDE, +) + +# enabling strictly pre-populated determinism forces the deterministic caching +# interface to pull from and only from a pre-populated in-memory cache. this +# in-memory cache gets pre-populated from a file path that is specified by +# environment variable "TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE". 
+# coincidentally, the deterministic caching interface will dump its in-memory +# cache to disk on program exit (check the logs for the exact file path) which +# can be used as a drop-in solution for pre-population on subsequent runs. if +# strictly pre-populated determinism is enabled and a key is encountered which +# is not covered by the pre-populated in-memory cache an exception, +# StrictDeterministicCachingKeyNotFoundError, will be raised +STRICTLY_PRE_POPULATED_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_STRICTLY_PRE_POPULATED_DETERMINISM", + default=False, +) +# similar to strictly pre-populated determinism, except that any key can either +# be in the pre-populated in-memory cache or the on-disk/remote cache (depending +# on whether or not local/global determinism is enabled). +STRICTLY_CACHED_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_STRICTLY_CACHED_DETERMINISM", + default=False, +) +# local determinism ensures that caching is deterministic on a single machine, +# hence an on-disk cache is used for synchronization of results +LOCAL_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_LOCAL_DETERMINISM", default=(not is_fbcode()) +) +# global determinism ensures that caching is deterministic across any/all machines, +# hence a remote cache (with strong consistency!) is used for synchronization of results +GLOBAL_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_GLOBAL_DETERMINISM", default=is_fbcode() +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/context.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/context.py new file mode 100644 index 0000000000000000000000000000000000000000..7f52a70ff6d70982a5626a1ff48d7078b6b4ccf8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/context.py @@ -0,0 +1,292 @@ +"""Context management for PyTorch Inductor runtime caching. 
+ +This module provides context classes for collecting configuration and environment +information used in caching decisions for PyTorch's Inductor runtime. +""" + +import json +from abc import ABC, abstractmethod +from base64 import b64encode +from collections.abc import Sequence +from functools import cache +from hashlib import sha256 +from typing import Any +from typing_extensions import override, TypedDict + +import torch + + +class _Context(ABC): + """Abstract base class for context providers. + + Context providers collect specific configuration and environment information + that affects compilation and runtime behavior. + """ + + @staticmethod + @abstractmethod + def forms_of_context() -> Sequence[str]: + """Return a sequence of context form names provided by this context class. + + Returns: + A sequence of strings representing the available context forms. + """ + + +class _RuntimeContext(_Context): + """Context provider for runtime configuration and environment settings. + + Collects configuration settings that affect runtime behavior but not + compilation, such as Inductor configs, determinism settings, and CUDA + matmul precision configurations. + """ + + @override + @staticmethod + def forms_of_context() -> Sequence[str]: + """Return the runtime context forms provided by this class. + + Returns: + A sequence containing the available runtime context forms: + - "inductor_configs": PyTorch Inductor configuration settings + - "torch_determinism_configs": Deterministic algorithm settings + - "cuda_matmul_precision_configs": CUDA matrix multiplication precision settings + """ + return ( + "inductor_configs", + "torch_determinism_configs", + "cuda_matmul_precision_configs", + ) + + @staticmethod + def inductor_configs() -> dict[str, Any]: + """Get portable Inductor configuration settings. + + Returns: + A dictionary containing Inductor configuration settings, + including private configs. 
+ """ + from torch._inductor import config + + return config.save_config_portable(ignore_private_configs=False) + + @staticmethod + def torch_determinism_configs() -> dict[str, Any]: + """Get PyTorch deterministic algorithm configuration settings. + + Returns: + A dictionary containing deterministic algorithm settings: + - Whether deterministic algorithms are enabled + - Whether deterministic algorithm warnings are enabled + - Fill uninitialized memory setting + """ + return { + "torch.are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(), + "torch.is_deterministic_algorithms_warn_only_enabled": ( + torch.is_deterministic_algorithms_warn_only_enabled() + ), + "torch.utils.deterministic.fill_uninitialized_memory": ( + torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ), + } + + @staticmethod + def cuda_matmul_precision_configs() -> dict[str, Any]: + """Get CUDA matrix multiplication precision configuration settings. + + Returns: + A dictionary containing CUDA matmul precision settings: + - FP32 precision setting + - FP16 reduced precision reduction allowance + - BF16 reduced precision reduction allowance + """ + return { + "torch.backends.cuda.matmul.fp32_precision": torch.backends.cuda.matmul.fp32_precision, + "torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction": ( + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + ), + "torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction": ( + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + ), + } + + +class _CompileContext(_Context): + """Context provider for compilation-related configuration and environment settings. + + Collects information that affects compilation behavior, such as PyTorch and Triton + versions, runtime environment, and accelerator properties. + """ + + @override + @staticmethod + def forms_of_context() -> Sequence[str]: + """Return the compile context forms provided by this class. 
+ + Returns: + A sequence containing the available compile context forms: + - "torch_version_hash": PyTorch version hash + - "triton_version_hash": Triton version hash (if available) + - "runtime": Runtime type (CUDA/HIP/None) + - "runtime_version": Runtime version string + - "accelerator_properties": GPU/accelerator properties + """ + return ( + "torch_version_hash", + "triton_version_hash", + "runtime", + "runtime_version", + "accelerator_properties", + ) + + @cache + @staticmethod + def torch_version_hash() -> str: + """Get base64-encoded PyTorch version hash. + + Returns: + A base64-encoded string representing the PyTorch version hash. + """ + from torch._inductor.codecache import torch_key + + return b64encode(torch_key()).decode() + + @cache + @staticmethod + def triton_version_hash() -> str | None: + """Get Triton version key if Triton is available. + + Returns: + Triton version key if Triton is available, None otherwise. + """ + from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key + + return triton_key() if HAS_TRITON else None + + @cache + @staticmethod + def runtime() -> str | None: + """Determine the runtime type based on available backends. + + Returns: + "CUDA" if CUDA is available, "HIP" if HIP is available, None otherwise. + """ + return "CUDA" if torch.version.cuda else "HIP" if torch.version.hip else None + + @cache + @staticmethod + def runtime_version() -> str | None: + """Get the version string for the detected runtime. + + Returns: + Version string for the current runtime (CUDA or HIP), or None if + no supported runtime is detected. + """ + return { + "CUDA": torch.version.cuda, + "HIP": torch.version.hip, + }.get(_CompileContext.runtime()) # type: ignore[arg-type] + + @cache + @staticmethod + def accelerator_properties() -> str | None: + """Get string representation of CUDA device properties. + + Returns: + String representation of CUDA device properties if a runtime is + available, None otherwise. 
+ """ + return ( + repr(torch.cuda.get_device_properties()) + if _CompileContext.runtime() and torch.cuda.is_available() + else None + ) + + +class SelectedRuntimeContext(TypedDict): + inductor_configs: bool + torch_determinism_configs: bool + cuda_matmul_precision_configs: bool + + +class SelectedCompileContext(TypedDict): + torch_version_hash: bool + triton_version_hash: bool + runtime: bool + runtime_version: bool + accelerator_properties: bool + + +class IsolationSchema(TypedDict): + """Schema for specifying which context forms to include in cache isolation. + + Attributes: + runtime_context: Either True (include all runtime context), False (exclude all), + or a SelectedRuntimeContext dict specifying which forms to include. + compile_context: Either True (include all compile context), False (exclude all), + or a SelectedCompileContext dict specifying which forms to include. + """ + + runtime_context: SelectedRuntimeContext | bool + compile_context: SelectedCompileContext | bool + + +_DEFAULT_ISOLATION_SCHEMA: IsolationSchema = IsolationSchema( + runtime_context=True, compile_context=True +) + + +def _isolation_context( + ischema: IsolationSchema = _DEFAULT_ISOLATION_SCHEMA, +) -> dict[str, Any]: + """Generate context data based on the isolation schema. + + Args: + ischema: Schema specifying which context forms to include. + Defaults to including all runtime and compile context. + + Returns: + A dictionary containing the selected context data with keys + "runtime_context" and "compile_context", where each value is + either None (if excluded) or a dict of context form data. 
+ """ + isolation_context: dict[str, Any] = {} + for context_name, context_cls in ( + ("runtime_context", _RuntimeContext), + ("compile_context", _CompileContext), + ): + selected_context: dict[str, Any] | None = None + if ischema[context_name] is True: # type: ignore[literal-required] + selected_context = { + form_of_context: getattr(context_cls, form_of_context)() + for form_of_context in context_cls.forms_of_context() + } + elif ischema[context_name] is False: # type: ignore[literal-required] + selected_context = None + else: + selected_context = {} + for form_of_context in ischema[context_name]: # type: ignore[literal-required] + selected = ischema[context_name][form_of_context] # type: ignore[literal-required] + if selected: + selected_context[form_of_context] = getattr( + context_cls, form_of_context + )() + selected_context = selected_context or None + isolation_context[context_name] = selected_context + return isolation_context + + +def _isolation_key(ischema: IsolationSchema = _DEFAULT_ISOLATION_SCHEMA) -> str: + """Generate a unique key for the given isolation schema. + + Args: + ischema: Schema specifying which context forms to include. + Defaults to including all runtime and compile context. + + Returns: + A 32-character hexadecimal string that uniquely identifies + the context specified by the isolation schema. 
+ """ + return sha256( + json.dumps(_isolation_context(ischema), sort_keys=True).encode() + ).hexdigest()[:32] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/exceptions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..02e47fa1e6127a44b45c61966d3aa6e3d9fb65da --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/exceptions.py @@ -0,0 +1,189 @@ +# pyre-strict + +"""Exception classes for PyTorch Inductor runtime caching. + +This module defines a hierarchy of exceptions used throughout the caching system. +All custom exceptions inherit from CacheError, with UserError serving as a base +for user-facing errors that also inherit from TypeError for compatibility. +""" + +from threading import Lock +from typing import Any + +from filelock import FileLock + + +class CacheError(Exception): + """Base class for all caching-related errors. + + This is the root exception class for all custom exceptions raised by the caching + module, providing a common interface for error handling and logging. + """ + + +class SystemError(CacheError, RuntimeError): + """Base class for system-level caching errors. + + This class represents errors that occur during cache operations, such as + storage or retrieval failures. It inherits from RuntimeError to indicate + that the error is not caused by user input. + """ + + +class LockTimeoutError(SystemError): + """Error raised when a lock operation times out. + + This exception is raised when a lock operation exceeds the specified timeout + limit, indicating that the lock could not be acquired within the allotted time. + """ + + def __init__(self, lock: Lock, timeout: float) -> None: + """Initialize the lock timeout error with detailed lock information. + + Args: + lock: The lock object that timed out. 
+ timeout: The timeout limit that was exceeded. + """ + super().__init__(f"Failed to acquire lock {lock} within {timeout} seconds.") + + +class FileLockTimeoutError(SystemError): + """Error raised when a file lock operation times out. + + This exception is raised when a file lock operation exceeds the specified timeout + limit, indicating that the lock could not be acquired within the allotted time. + """ + + def __init__(self, flock: FileLock, timeout: float) -> None: + """Initialize the file lock timeout error with detailed lock information. + + Args: + flock: The file lock object that timed out. + timeout: The timeout limit that was exceeded. + """ + super().__init__( + f"Failed to acquire file lock {flock} within {timeout} seconds." + ) + + +class UserError(CacheError, TypeError): + """Base class for user-facing cache errors that also inherit from TypeError. + + This class combines CacheError with TypeError to provide compatibility + with existing exception handling patterns while maintaining the cache + error hierarchy. All user-facing cache errors should inherit from this class. + """ + + +class KeyEncodingError(UserError): + """Base class for errors that occur during cache key encoding operations. + + Raised when cache keys cannot be properly encoded for storage or transmission. + This includes serialization, hashing, or other encoding-related failures. + """ + + +class KeyPicklingError(KeyEncodingError): + """Error raised when a cache key cannot be pickled for serialization. + + This typically occurs when trying to cache objects with keys that contain + non-serializable components, lambda functions, or other unpickleable types. + """ + + def __init__(self, key: Any) -> None: + """Initialize the key pickling error with detailed key information. + + Args: + key: The cache key that failed to be pickled. + """ + super().__init__( + f"Failed to pickle cache key with type {type(key)} and value {key!r}." 
+ ) + + +class ValueEncodingError(UserError): + """Base class for errors that occur during cache value encoding operations. + + Raised when cache values cannot be properly encoded for storage or transmission. + This includes serialization, compression, or other encoding-related failures. + """ + + +class ValuePicklingError(ValueEncodingError): + """Error raised when a cache value cannot be pickled for serialization. + + This occurs when trying to cache objects that contain non-serializable + components, file handles, network connections, or other unpickleable types. + """ + + def __init__(self, value: Any) -> None: + """Initialize the value pickling error with detailed value information. + + Args: + value: The cache value that failed to be pickled. + """ + super().__init__( + f"Failed to pickle cache value with type {type(value)} and value {value!r}." + ) + + +class ValueDecodingError(UserError): + """Base class for errors that occur during cache value decoding operations. + + Raised when cached values cannot be properly decoded during retrieval. + This includes deserialization, decompression, or other decoding-related failures. + """ + + +class ValueUnPicklingError(ValueDecodingError): + """Error raised when cached value data cannot be unpickled during retrieval. + + This typically indicates corruption, version incompatibility, or missing + dependencies required to reconstruct the cached object. + """ + + def __init__(self, pickled_value: bytes) -> None: + """Initialize the value unpickling error with the problematic data. + + Args: + pickled_value: The bytes that failed to be unpickled. + """ + super().__init__( + f"Failed to unpickle cache value from pickled value {pickled_value!r}." 
+ ) + + +class CustomParamsEncoderRequiredError(UserError): + pass + + +class CustomResultEncoderRequiredError(UserError): + pass + + +class CustomResultDecoderRequiredError(UserError): + pass + + +class DeterministicCachingDisabledError(UserError): + pass + + +class DeterministicCachingRequiresStrongConsistencyError(UserError): + pass + + +class StrictDeterministicCachingKeyNotFoundError(UserError): + pass + + +class DeterministicCachingInvalidConfigurationError(UserError): + pass + + +class StrictDeterministicCachingInsertionError(UserError): + pass + + +class DeterministicCachingIMCDumpConflictError(SystemError): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/implementations.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/implementations.py new file mode 100644 index 0000000000000000000000000000000000000000..ed83e490fd316059e7d877b63adb2eeaec69ed70 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/implementations.py @@ -0,0 +1,415 @@ +"""Cache implementation classes for PyTorch Inductor runtime caching. + +This module provides concrete implementations of caching backends including +in-memory, on-disk, and remote caching strategies. Each implementation follows +the abstract _CacheImpl interface and provides thread-safe operations with +appropriate locking mechanisms. +""" + +from abc import ABC, abstractmethod +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass +from hashlib import sha256 +from io import BufferedReader, BufferedWriter +from os import PathLike +from pathlib import Path +from threading import Lock +from typing import Any +from typing_extensions import override + +from filelock import FileLock + +from . import locks, utils + + +@dataclass +class Hit: + """Result wrapper for hits on cache get operations. 
+ + Allows distinguishing between a cache miss and a cached None value. + + Attributes: + value: The cached value. + """ + + value: Any + + +class Miss: + """Sentinel class representing a cache miss. + + Used to distinguish between a cached None value and a cache miss + when None is a valid cached value. + """ + + +# Singleton instance for cache miss sentinel +miss = Miss() + + +class _CacheImpl(ABC): + """Abstract base class for cache implementations. + + This class defines the interface that all cache implementations must follow. + It provides thread-safe operations through a locking mechanism and supports + both get and insert operations. + + Note: We don't use generics here as doing so would require that the interfaces + know which k/v types the implementation can work with. Instead, we leave that + determination up to the implementation itself and require that the interfaces + handle any potential errors from invalid k/v types being passed to the + implementation. + """ + + def __init__(self) -> None: + """Initialize the cache implementation with a threading lock.""" + self._lock: Lock = Lock() + + @property + def lock(self) -> locks._LockProtocol: + """Get a context manager for acquiring the cache lock. + + Locking of the cache is not done by the implementation itself, but by the + interface that uses it. The interface may want to hold the lock for longer + than a single cache operation, for example when dealing with multiple + cache implementations at once, so we leave that decision up to the interface. + + Args: + timeout: Optional timeout in seconds (float) for acquiring the lock. + + Returns: + A callable that returns a context manager for the lock. + """ + + def _lock_with_timeout( + timeout: float | None = None, + ) -> locks._LockContextManager: + return locks._acquire_lock_with_timeout(self._lock, timeout) + + return _lock_with_timeout + + @abstractmethod + def get(self, key: Any) -> Hit | None: + """Retrieve a value from the cache. 
+ + Args: + key: The key to look up in the cache. + + Returns: + A Hit object on cache hit where Hit.value is the cached value, + or None on cache miss. + """ + + @abstractmethod + def insert(self, key: Any, value: Any) -> bool: + """Insert a key-value pair into the cache. + + Args: + key: The key to insert. + value: The value to associate with the key. + + Returns: + True if the insertion was successful, False if not inserted. + """ + + +class _InMemoryCacheImpl(_CacheImpl): + """In-memory cache implementation using a dictionary. + + This implementation stores key-value pairs in a Python dictionary, + with keys being pickled for consistent hashing. It provides fast + access but is limited by available memory and process lifetime. + """ + + def __init__(self) -> None: + """Initialize the in-memory cache with an empty dictionary.""" + super().__init__() + self._memory: dict[bytes, Any] = {} + + @override + def get(self, key: Any) -> Hit | None: + """Retrieve a value from the in-memory cache. + + Args: + key: The key to look up. Will be pickled for storage. + + Returns: + A Hit object on cache hit where Hit.value is the cached value, + or None on cache miss. + """ + pickled_key: bytes = utils._try_pickle_key(key) + if (value := self._memory.get(pickled_key, miss)) is not miss: + return Hit(value=value) + return None + + @override + def insert(self, key: Any, value: Any) -> bool: + """Insert a key-value pair into the in-memory cache. + + Args: + key: The key to insert. Will be pickled for storage. + value: The value to associate with the key. + + Returns: + True if the insertion was successful (key was new), + False if not inserted (key already existed). + """ + pickled_key: bytes = utils._try_pickle_key(key) + if pickled_key not in self._memory: + self._memory[pickled_key] = value + return True + return False + + +class _OnDiskCacheImpl(_CacheImpl): + """On-disk cache implementation using file system storage. 
+ + This implementation stores cached data as files on disk, with version + headers to handle cache invalidation. It uses file locking to ensure + thread safety across processes and provides persistent storage that + survives process restarts. + + Attributes: + _version: Version number for cache format compatibility. + _version_header_length: Length of the version header in bytes. + """ + + _version: int = 0 + _version_header_length: int = 4 + + def __init__(self, sub_dir: PathLike[str] | None = None) -> None: + """Initialize the on-disk cache with a specified subdirectory. + + Args: + sub_dir: Subdirectory name within the cache directory. + Defaults to empty string if not specified. + """ + self._cache_dir: Path = self._base_dir / (sub_dir or "") + # pyrefly: ignore [bad-assignment] + self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock")) + + @property + def _base_dir(self) -> Path: + """Get the base directory for cache storage. + + Returns: + Path to the cache directory based on the default cache dir + and the specified subdirectory. + """ + from torch._inductor.runtime.runtime_utils import default_cache_dir + + return Path(default_cache_dir(), "cache") + + def _fpath_from_key(self, key: Any) -> Path: + """Generate a file path from a cache key. + + Args: + key: The cache key to convert to a file path. + + Returns: + A Path object representing the file location for this key. + """ + pickled_key: bytes = utils._try_pickle_key(key) + return self._cache_dir / sha256(pickled_key).hexdigest()[:32] + + @classmethod + def _version_header(cls) -> bytes: + """Generate the version header bytes. + + Returns: + A byte string representing the current cache version header. + """ + return sha256(str(cls._version).encode()).digest()[: cls._version_header_length] + + def _version_header_matches(self, fp: BufferedReader) -> bool: + """Check if the file's version header matches the current version. + + Args: + fp: File pointer positioned at the start of the file. 
+ + Returns: + True if the version header matches, False otherwise. + """ + return fp.read(self._version_header_length) == self._version_header() + + def _write_version_header(self, fp: BufferedWriter) -> None: + """Write the version header to a file. + + Args: + fp: File pointer where the version header should be written. + """ + fp.write(self._version_header()) + + @override + @property + def lock(self) -> locks._LockProtocol: + """Get a context manager for acquiring the file lock. + + Uses file locking to ensure thread safety across processes. + + Args: + timeout: Optional timeout in seconds (float) for acquiring the file lock. + + Returns: + A callable that returns a context manager for the file lock. + """ + + def _lock_with_timeout( + timeout: float | None = None, + ) -> locks._LockContextManager: + return locks._acquire_flock_with_timeout(self._flock, timeout) + + return _lock_with_timeout + + @override + def get(self, key: Any) -> Hit | None: + """Retrieve a value from the on-disk cache. + + Args: + key: The key to look up in the cache. + + Returns: + A Hit object on cache hit where Hit.value is the cached value, + or None on cache miss or version mismatch. + """ + fpath: Path = self._fpath_from_key(key) + + if not fpath.is_file(): + return None + + pickled_value: bytes | None = None + with open(fpath, "rb") as fp: + if self._version_header_matches(fp): + pickled_value = fp.read() + + if not pickled_value: + # if pickled_value is still None, even though the file exists, then + # we know that the version header did not match. in this case implementation + # is up to preference, we choose to remove entries that do not match + # the version header so that the key can be re-cached later with the correct + # version header + fpath.unlink() + return None + + return Hit(value=utils._try_unpickle_value(pickled_value)) + + @override + def insert(self, key: Any, value: Any) -> bool: + """Insert a key-value pair into the on-disk cache. 
+ + Args: + key: The key to insert. + value: The value to associate with the key. + + Returns: + True if successfully inserted, False if the key already exists + with a valid version. + """ + fpath: Path = self._fpath_from_key(key) + fpath.parent.mkdir(parents=True, exist_ok=True) + + r_fp, w_fp, inserted = None, None, False + try: + w_fp = open(fpath, "xb") # noqa: SIM115 + except FileExistsError: + is_stale: bool = False + with open(fpath, "rb") as r_fp: + is_stale = not self._version_header_matches(r_fp) + + if is_stale: + # same story as above, in this case the version header doesn't + # match so we choose to remove the old entry so that the new + # k/v pair can be cached + fpath.unlink() + w_fp = open(fpath, "xb") # noqa: SIM115 + else: + w_fp = None + finally: + if w_fp: + try: + pickled_value: bytes = utils._try_pickle_value(value) + self._write_version_header(w_fp) + w_fp.write(pickled_value) + inserted = True + finally: + w_fp.close() + + return inserted + + +try: + from .fb.implementations import _RemoteCacheImpl +except ModuleNotFoundError: + + class _RemoteCacheImpl(_CacheImpl): # type: ignore[no-redef] + """Fallback remote cache implementation for non-Facebook environments. + + This is a no-op implementation that always raises NotImplementedError. + The actual remote cache implementation is provided in the `.fb` module + for Facebook-specific environments. + + Attributes: + _version: Version number for cache format compatibility. + has_strong_consistency: Whether the remote cache provides strong + consistency guarantees. + """ + + _version: int = 0 + has_strong_consistency: bool = False + + def __init__(self) -> None: + """Initialize the fallback remote cache implementation. + + Note: We don't need to initialize any form of lock since this + implementation provides a pseudo-lock context manager. + """ + + @override + @property + def lock(self) -> locks._LockProtocol: + """Get a pseudo lock that does nothing. 
+ + Most remote cache implementations don't have an ability to implement + any form of locking, so we provide a no-op pseudo-lock for consistency + with the interface. + + Args: + timeout: Optional timeout in seconds (float). Ignored in this + + Returns: + A callable that returns a no-op context manager. + """ + + @contextmanager + def pseudo_lock( + timeout: float | None = None, + ) -> Generator[None, None, None]: + yield + + return pseudo_lock + + @override + def get(self, key: Any) -> Hit | None: + """Raise NotImplementedError for remote cache get operations. + + Args: + key: The key to look up (ignored). + + Raises: + NotImplementedError: Always raised as this is a fallback implementation. + """ + raise NotImplementedError + + @override + def insert(self, key: Any, value: Any) -> bool: + """Raise NotImplementedError for remote cache insert operations. + + Args: + key: The key to insert (ignored). + value: The value to insert (ignored). + + Raises: + NotImplementedError: Always raised as this is a fallback implementation. 
+ """ + raise NotImplementedError diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/interfaces.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4b8251bc3997c6e03e742af55ad879266eaa73 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/interfaces.py @@ -0,0 +1,818 @@ +from __future__ import annotations + +import atexit +import json +import os +from abc import ABC, abstractmethod +from ast import literal_eval +from enum import Enum +from functools import partial, wraps +from logging import DEBUG, getLogger, INFO, Logger +from os import PathLike +from pathlib import Path +from threading import Lock +from time import time +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override + +from . import config, context, exceptions, implementations as impls, locks + + +if TYPE_CHECKING: + from collections.abc import Callable + + from .utils import P, R + + +# ideally we could annotate this as tuple[P.args, P.kwargs] but +# functionally that doesn't work as P is defined in a specific +# scope and P.args/P.kwargs are only valid in that scope +Params: TypeAlias = tuple[Any, Any] + +logger: Logger = getLogger(__name__) + + +class _IntfCallbackOrigin(Enum): + RECORD = "record" + GET = "get" + INSERT = "insert" + + +class _IntfCallbackAction(Enum): + REPLAY = "replay" + RECORD_INSERTED = "record_inserted" + RECORD_NOT_INSERTED = "record_not_inserted" + RECORD_NOT_INSERTED_REPLAY = "record_not_inserted_replay" + HIT = "hit" + MISS = "miss" + INSERTED = "inserted" + NOT_INSERTED = "not_inserted" + + +def _intf_callback( + origin: _IntfCallbackOrigin, + action: _IntfCallbackAction, + dur: float, + fn: Callable[P, R], + params: Params, + *args: Any, +) -> None: + if origin == _IntfCallbackOrigin.RECORD: + result: R 
= args[0] + if action == _IntfCallbackAction.REPLAY: + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r cached, " + "returned result %r in %f seconds.", + fn.__name__, + params, + result, + dur, + ) + elif action == _IntfCallbackAction.RECORD_INSERTED: + fn_dur: float = args[1] + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r not cached, " + "calculated and cached result %r in %f seconds " + "of which %f seconds was spent on the function call.", + fn.__name__, + params, + result, + dur, + fn_dur, + ) + elif action == _IntfCallbackAction.RECORD_NOT_INSERTED: + fn_dur = args[1] + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r not cached, " + "calculated result %r but was not able to " + "insert it into the cache as a matching " + "entry already exists; returned calculated result in %f seconds " + "of which %f seconds was spent on the function call.", + fn.__name__, + params, + result, + dur, + fn_dur, + ) + elif action == _IntfCallbackAction.RECORD_NOT_INSERTED_REPLAY: + fn_dur = args[1] + cached_result: R = args[2] + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r not cached, " + "calculated result %r but was not able to " + "insert it into the synchronization cache as a matching " + "entry already exists; returned cached result %r in %f seconds " + "of which %f seconds was spent on the function call.", + fn.__name__, + params, + result, + cached_result, + dur, + fn_dur, + ) + else: + raise NotImplementedError + elif origin == _IntfCallbackOrigin.GET: + if action == _IntfCallbackAction.HIT: + result = args[0] + logger.log( + DEBUG, + "[GET] for fn %s with params %r cached, " + "returned result %r in %f seconds.", + fn.__name__, + params, + result, + dur, + ) + elif action == _IntfCallbackAction.MISS: + logger.log( + DEBUG, + "[GET] for fn %s with params %r not cached, " + "returned nothing in %f seconds.", + fn.__name__, + params, + dur, + ) + else: + raise NotImplementedError + elif origin == _IntfCallbackOrigin.INSERT: + 
result = args[0] + if action == _IntfCallbackAction.INSERTED: + logger.log( + DEBUG, + "[INSERT] for fn %s with params %r and " + "result %r inserted in %f seconds.", + fn.__name__, + params, + result, + dur, + ) + elif action == _IntfCallbackAction.NOT_INSERTED: + logger.log( + DEBUG, + "[INSERT] for fn %s with params %r and " + "result %r not inserted in %f seconds as there is " + "already has a matching entry.", + fn.__name__, + params, + result, + dur, + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + +class _CacheIntf(ABC): + def __init__(self) -> None: + self._lock: Lock = Lock() + + def _make_key( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + ) -> Any: + callee: str = fn.__name__ + fkey: Any = ( + (callee, params) + if not custom_params_encoder + # pyrefly: ignore [invalid-param-spec] + else (callee, custom_params_encoder(*params[0], **params[1])) + ) + ikey: Any = context._isolation_key( + ischema if ischema is not None else context._DEFAULT_ISOLATION_SCHEMA + ) + return (fkey, ikey) + + def _make_dummy_record_wrapper(self, fn: Callable[P, R]) -> Callable[P, R]: + @wraps(fn) + def dummy_wrapper(*args: Any, **kwargs: Any) -> R: + # pyrefly: ignore [invalid-param-spec] + return fn(*args, **kwargs) + + # pyrefly: ignore [bad-return] + return dummy_wrapper + + @abstractmethod + def _make_record_wrapper( + self, + fn: Callable[P, R], + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> Callable[P, R]: + pass + + @abstractmethod + def _get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = 
None, + ) -> impls.Hit | None: + pass + + @abstractmethod + def _insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + pass + + @property + def lock(self) -> locks._LockProtocol: + """Get a context manager for acquiring the file lock. + + Uses file locking to ensure thread safety across processes. + + Args: + timeout: Optional timeout in seconds (float) for acquiring the file lock. + + Returns: + A callable that returns a context manager for the file lock. + """ + + def _lock_with_timeout( + timeout: float | None = None, + ) -> locks._LockContextManager: + return locks._acquire_lock_with_timeout(self._lock, timeout) + + return _lock_with_timeout + + def get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + if not config.IS_CACHING_MODULE_ENABLED(): + return None + + start_t: float = time() + with self.lock(): # type: ignore[call-arg] + result: impls.Hit | None = self._get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + dur: float = time() - start_t + + _intf_callback( + _IntfCallbackOrigin.GET, + _IntfCallbackAction.HIT if result else _IntfCallbackAction.MISS, + dur, + fn, + params, + *((result.value,) if result else ()), + ) + + return result + + def insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + if not config.IS_CACHING_MODULE_ENABLED(): + return False + + start_t: float = time() + with 
self.lock(): # type: ignore[call-arg] + inserted: bool = self._insert( + fn, + params, + result, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + ) + dur: float = time() - start_t + + _intf_callback( + _IntfCallbackOrigin.INSERT, + _IntfCallbackAction.INSERTED + if inserted + else _IntfCallbackAction.NOT_INSERTED, + dur, + fn, + params, + result, + ) + + return inserted + + def record( + self, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[..., Any] | None = None, + custom_result_encoder: Callable[..., Any] | None = None, + custom_result_decoder: Callable[..., ...] | None = None, + ) -> Callable[[Callable[..., ...]], Callable[..., ...]]: + if custom_result_encoder and not custom_result_decoder: + raise exceptions.CustomResultDecoderRequiredError( + "Custom result encoder provided without custom result decoder." + ) + elif not custom_result_encoder and custom_result_decoder: + raise exceptions.CustomResultEncoderRequiredError( + "Custom result decoder provided without custom result encoder." 
+ ) + elif not config.IS_CACHING_MODULE_ENABLED(): + return self._make_dummy_record_wrapper + else: + return partial( + self._make_record_wrapper, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + custom_result_decoder=custom_result_decoder, + ) + + +class _FastCacheIntf(_CacheIntf): + def __init__(self) -> None: + super().__init__() + self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl() + self._callee_to_odc: dict[str, impls._OnDiskCacheImpl] = {} + + def _get_odc_from_callee(self, callee: str) -> impls._OnDiskCacheImpl: + if not (odc := self._callee_to_odc.get(callee)): + callee_sub_dir: PathLike[str] = Path(callee) + odc = impls._OnDiskCacheImpl(sub_dir=callee_sub_dir) + self._callee_to_odc[callee] = odc + # pyrefly: ignore [unbound-name] + return odc + + @override + def _make_record_wrapper( + self, + fn: Callable[P, R], + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> Callable[P, R]: + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + start_t: float = time() + params = ( + args, + kwargs, + ) + with self.lock(): + get: impls.Hit | None = self._get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + + if get: + dur: float = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.REPLAY, + dur, + fn, + params, + get.value, + ) + return get.value + else: + fn_start_t: float = time() + result: R = fn(*args, **kwargs) + fn_dur: float = time() - fn_start_t + inserted: bool = self._insert( + fn, + params, + result, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + ) + dur = time() - start_t + _intf_callback( + 
_IntfCallbackOrigin.RECORD, + _IntfCallbackAction.RECORD_INSERTED + if inserted + else _IntfCallbackAction.RECORD_NOT_INSERTED, + dur, + fn, + params, + result, + fn_dur, + ) + return result + + return wrapper + + @override + def _get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + odc: impls._OnDiskCacheImpl = self._get_odc_from_callee(fn.__name__) + with locks._acquire_many_impl_locks_with_timeout(self._imc, odc): + try: + # we'll check the memoization first, since that is much faster + # than checking the on-disk cache (and the two should be consistent + # regardless) + imc_get: impls.Hit | None = self._imc.get(key) + if imc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(imc_get.value)) + else: + return imc_get + else: + odc_get: impls.Hit | None = odc.get(key) + if odc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(odc_get.value)) + return odc_get + return None + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + + @override + def _insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + odc: impls._OnDiskCacheImpl = self._get_odc_from_callee(fn.__name__) + with locks._acquire_many_impl_locks_with_timeout(self._imc, odc): + try: + encoded_result: Any = ( + result + if not custom_result_encoder + else custom_result_encoder(result) + ) + # 
reverse order of get, as we don't want to memoize values + # if we haven't actually inserted them into the on-disk cache + # so that the memoization and the on-disk cache remain consistent + if odc.insert(key, encoded_result): + assert self._imc.insert(key, encoded_result) + return True + return False + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + except exceptions.ValueEncodingError as err: + raise exceptions.CustomResultEncoderRequiredError( + f"Custom result encoder required for function {fn} with parameters {params} and result {result}." + ) from err + + +class _DeterministicCacheIntf(_CacheIntf): + def __init__(self) -> None: + super().__init__() + self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl() + + if fpath_str := os.environ.get( + "TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE" + ): + fpath: Path = Path(fpath_str) + fpath_parent: PathLike[str] = fpath.parent + if fpath.is_file(): + odc: impls._OnDiskCacheImpl = impls._OnDiskCacheImpl( + sub_dir=fpath_parent + ) + with odc.lock(): + with open(fpath) as fp: + dump_for_pre_population: dict[str, str] = json.load(fp) + for key_r, value_r in dump_for_pre_population.items(): + key: bytes = literal_eval(key_r) + value: bytes = literal_eval(value_r) + self._imc._memory[key] = value + + if config.STRICTLY_PRE_POPULATED_DETERMINISM: + # we'll never need a synchronization cache if we're in strictly pre-populated mode, + # as we'll only ever be checking the memoized pre-population + self._get_sc_from_callee: Callable[ + [str], None | impls._OnDiskCacheImpl | impls._RemoteCacheImpl + ] = lambda callee: None + elif config.GLOBAL_DETERMINISM: + # if we want global determinism we need to use a remote cache with strong + # consistency as the synchronization cache + self._rc: impls._RemoteCacheImpl = impls._RemoteCacheImpl() + if not self._rc.has_strong_consistency: + raise exceptions.DeterministicCachingRequiresStrongConsistencyError + 
self._get_sc_from_callee = lambda callee: self._rc + elif config.LOCAL_DETERMINISM: + # local determinism can use the on-disk cache as the synchronization cache, + # for cleanliness of the on-disk cache we subdir based on the callee + self._callee_to_odc: dict[str, impls._OnDiskCacheImpl] = {} + self._get_sc_from_callee = self._get_odc_from_callee + else: + raise exceptions.DeterministicCachingInvalidConfigurationError( + "Deterministic caching must specify at least one of STRICTLY_PRE_POPULATED_DETERMINISM, " + "GLOBAL_DETERMINISM, or LOCAL_DETERMINISM." + ) + + atexit.register(self._dump_imc_to_disk) + + def __del__(self) -> None: + atexit.unregister(self._dump_imc_to_disk) + del self + + def _get_odc_from_callee(self, callee: str) -> impls._OnDiskCacheImpl: + if not (odc := self._callee_to_odc.get(callee)): + callee_sub_dir: PathLike[str] = Path(callee) + odc = impls._OnDiskCacheImpl(sub_dir=callee_sub_dir) + self._callee_to_odc[callee] = odc + # pyrefly: ignore [unbound-name] + return odc + + def _dump_imc_to_disk(self) -> Path | None: + with self.lock(): # type: ignore[call-arg] + to_dump: dict[str, str] = { + repr(key): repr(value) for key, value in self._imc._memory.items() + } + if not to_dump: + return None + + odc: impls._OnDiskCacheImpl = impls._OnDiskCacheImpl( + sub_dir=Path("dcache_dump") + ) + fpath: Path = odc._cache_dir / "imc.save" + with odc.lock(): + w_fp = None + try: + w_fp = open(fpath, "x") # noqa:SIM115 + except FileExistsError: + with open(fpath) as r_fp: + existing_dump = json.load(r_fp) + + for key, value in existing_dump.items(): + if key not in to_dump: + to_dump[key] = value + elif to_dump[key] != value: + raise exceptions.DeterministicCachingIMCDumpConflictError from None + + w_fp = open(fpath, "w") # noqa:SIM115 + finally: + assert w_fp is not None + try: + json.dump(to_dump, w_fp, indent=4) + logger.log( + INFO, "Dumped deterministic cache memoization to %s", fpath + ) + finally: + w_fp.close() + + return fpath + + @override + def 
_make_record_wrapper( + self, + fn: Callable[P, R], + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> Callable[P, R]: + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if not config.IS_DETERMINISTIC_CACHING_ENABLED(): + raise exceptions.DeterministicCachingDisabledError + start_t: float = time() + params = ( + args, + kwargs, + ) + with self.lock(): + get: impls.Hit | None = self._get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + + if get: + dur: float = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.REPLAY, + dur, + fn, + params, + get.value, + ) + return get.value + else: + fn_start_t: float = time() + result: R = fn(*args, **kwargs) + fn_dur: float = time() - fn_start_t + if not self._insert( + fn, + params, + result, + ischema, + custom_params_encoder, + custom_result_encoder, + ): + # if we couldn't insert that means that some other callee has populated + # the key entry in the remote cache within the time between our first get + # and the insert attempt; in that case, to be deterministic, we should + # call get again and return that value as the assumption is that other + # compile workers will also use that value + get = self._get( + fn, + params, + ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + assert get is not None, ( + "remote cache should get(key) if insert(key, _) failed" + ) + dur = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.RECORD_NOT_INSERTED_REPLAY, + dur, + fn, + params, + fn_dur, + get.value, + ) + return get.value + dur = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + 
_IntfCallbackAction.RECORD_INSERTED, + dur, + fn, + params, + result, + fn_dur, + ) + return result + + return wrapper + + @override + def _get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + sc: impls._OnDiskCacheImpl | impls._RemoteCacheImpl | None = ( + self._get_sc_from_callee(fn.__name__) + ) + with locks._acquire_many_impl_locks_with_timeout( + *([self._imc, sc] if sc else [self._imc]) + ): + try: + # we'll check the memoization first, since that is much faster + # than checking the remote cache and the two should be consistent + imc_get: impls.Hit | None = self._imc.get(key) + if imc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(imc_get.value)) + else: + return imc_get + elif not sc: + raise exceptions.StrictDeterministicCachingKeyNotFoundError + else: + sc_get: impls.Hit | None = sc.get(key) + if sc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(sc_get.value)) + return sc_get + elif config.STRICTLY_CACHED_DETERMINISM: + raise exceptions.StrictDeterministicCachingKeyNotFoundError + return None + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + + @override + def _insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + if ( + config.STRICTLY_PRE_POPULATED_DETERMINISM + or config.STRICTLY_CACHED_DETERMINISM + ): + raise exceptions.StrictDeterministicCachingInsertionError + + key: Any = self._make_key( + fn, params, ischema=ischema, 
custom_params_encoder=custom_params_encoder + ) + sc: impls._OnDiskCacheImpl | impls._RemoteCacheImpl | None = ( + self._get_sc_from_callee(fn.__name__) + ) + assert sc, ( + "sc should be either an on-disk cache or a remote cache if we're inserting" + ) + with locks._acquire_many_impl_locks_with_timeout(self._imc, sc): + try: + encoded_result: Any = ( + result + if not custom_result_encoder + else custom_result_encoder(result) + ) + # reverse order of get, as we don't want to memoize values + # if we haven't actually inserted them into the remote cache + # so that the memoization and the remote cache remain consistent + if sc.insert(key, encoded_result): + if not self._imc.insert(key, encoded_result): + # imc might have the mapping already, if pre-populated + assert self._imc.get(key) == encoded_result + return True + return False + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + except exceptions.ValueEncodingError as err: + raise exceptions.CustomResultEncoderRequiredError( + f"Custom result encoder required for function {fn} with parameters {params} and result {result}." 
+ ) from err + + @override + def get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + if not config.IS_DETERMINISTIC_CACHING_ENABLED(): + raise exceptions.DeterministicCachingDisabledError + return super().get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + + @override + def insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + if not config.IS_DETERMINISTIC_CACHING_ENABLED(): + raise exceptions.DeterministicCachingDisabledError + return super().insert( + fn, + params, + result, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/locks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/locks.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8cd011e2d443a814b01842db5677cab6e70132 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/locks.py @@ -0,0 +1,202 @@ +"""Lock acquisition utilities for caching system with timeout support. + +This module provides safe and unsafe lock acquisition functions for both threading.Lock +and FileLock objects, with configurable timeout behaviors. It supports three timeout modes: +blocking (infinite wait), non-blocking (immediate), and blocking with timeout (finite wait). 
+ +The module offers both context manager and manual acquisition patterns: +- Safe acquisition: Uses context managers that automatically handle lock release +- Unsafe acquisition: Manual acquisition that requires explicit release by the caller +""" + +from __future__ import annotations + +from contextlib import _GeneratorContextManager, contextmanager, ExitStack +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol + +from filelock import FileLock, Timeout + +from . import exceptions, implementations as impls + + +if TYPE_CHECKING: + from collections.abc import Generator + from threading import Lock + + +_LockContextManager: TypeAlias = _GeneratorContextManager[None, None, None] + + +class _LockProtocol(Protocol): # noqa: PYI046 + def __call__(self, timeout: float | None = None) -> _LockContextManager: ... + + +# Infinite timeout - blocks indefinitely until lock is acquired. +_BLOCKING: float = -1 +# No timeout - returns immediately if lock cannot be acquired. +_NON_BLOCKING: float = 0 +# Finite timeout - blocks for a specified duration before raising a timeout error. +_BLOCKING_WITH_TIMEOUT: float = 60.0 +# Default timeout for lock acquisition. +_DEFAULT_TIMEOUT: float = _BLOCKING_WITH_TIMEOUT + + +@contextmanager +def _acquire_lock_with_timeout( + lock: Lock, + timeout: float | None = None, +) -> Generator[None, None, None]: + """Context manager that safely acquires a threading.Lock with timeout and automatically releases it. + + This function provides a safe way to acquire a lock with timeout support, ensuring + the lock is always released even if an exception occurs during execution. + + Args: + lock: The threading.Lock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. 
+ - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Yields: + None: Yields control to the caller while holding the lock + + Raises: + LockTimeoutError: If the lock cannot be acquired within the timeout period + + Example: + with _acquire_lock_with_timeout(my_lock, timeout=30.0): + # Critical section - lock is held + perform_critical_operation() + # Lock is automatically released here + """ + _unsafe_acquire_lock_with_timeout(lock, timeout=timeout) + + try: + yield + finally: + lock.release() + + +def _unsafe_acquire_lock_with_timeout(lock: Lock, timeout: float | None = None) -> None: + """Acquire a threading.Lock with timeout without automatic release (unsafe). + + This function acquires a lock with timeout support but does NOT automatically + release it. The caller is responsible for releasing the lock explicitly. + Use this only when you need manual control over lock lifetime. + + Args: + lock: The threading.Lock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. + - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Raises: + LockTimeoutError: If the lock cannot be acquired within the timeout period + + Warning: + This is an "unsafe" function because it does not automatically release + the lock. Always call lock.release() when done, preferably in a try/finally + block or use the safe _acquire_lock_with_timeout context manager instead. + + Example: + lock = Lock() + try: + _unsafe_acquire_lock_with_timeout(lock, timeout=30.0) + # Critical section - lock is held + perform_critical_operation() + finally: + lock.release() # Must manually release! 
+ """ + _timeout: float = timeout if timeout is not None else _DEFAULT_TIMEOUT + if not lock.acquire(timeout=_timeout): + raise exceptions.LockTimeoutError(lock, _timeout) + + +@contextmanager +def _acquire_flock_with_timeout( + flock: FileLock, + timeout: float | None = None, +) -> Generator[None, None, None]: + """Context manager that safely acquires a FileLock with timeout and automatically releases it. + + This function provides a safe way to acquire a file lock with timeout support, ensuring + the lock is always released even if an exception occurs during execution. + + Args: + flock: The FileLock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. + - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Yields: + None: Yields control to the caller while holding the file lock + + Raises: + FileLockTimeoutError: If the file lock cannot be acquired within the timeout period + + Example: + flock = FileLock("/tmp/my_process.lock") + with _acquire_flock_with_timeout(flock, timeout=30.0): + # Critical section - file lock is held + perform_exclusive_file_operation() + # File lock is automatically released here + """ + _unsafe_acquire_flock_with_timeout(flock, timeout=timeout) + + try: + yield + finally: + flock.release() + + +def _unsafe_acquire_flock_with_timeout(flock: FileLock, timeout: float | None = None) -> None: + """Acquire a FileLock with timeout without automatic release (unsafe). + + This function acquires a file lock with timeout support but does NOT automatically + release it. The caller is responsible for releasing the lock explicitly. + Use this only when you need manual control over lock lifetime. + + Args: + flock: The FileLock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. 
+ - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Raises: + FileLockTimeoutError: If the file lock cannot be acquired within the timeout period + + Warning: + This is an "unsafe" function because it does not automatically release + the lock. Always call flock.release() when done, preferably in a try/finally + block or use the safe _acquire_flock_with_timeout context manager instead. + + Example: + flock = FileLock("/tmp/my_process.lock") + try: + _unsafe_acquire_flock_with_timeout(flock, timeout=30.0) + # Critical section - file lock is held + perform_exclusive_file_operation() + finally: + flock.release() # Must manually release! + """ + _timeout: float = timeout if timeout is not None else _DEFAULT_TIMEOUT + try: + _ = flock.acquire(timeout=_timeout) + except Timeout as err: + raise exceptions.FileLockTimeoutError(flock, _timeout) from err + + +@contextmanager +def _acquire_many_impl_locks_with_timeout( + *impls: impls._CacheImpl, + timeout: float | None = None, +) -> Generator[None, None, None]: + with ExitStack() as stack: + for impl in impls: + stack.enter_context(impl.lock(timeout)) + yield diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb25573f2e37346d2f16501f4fb6ff731353cef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/caching/utils.py @@ -0,0 +1,109 @@ +"""Utility functions for caching operations in PyTorch Inductor runtime. + +This module provides helper functions for pickling/unpickling operations +with error handling, LRU caching decorators, and type-safe serialization +utilities used throughout the caching system. 
+""" + +import pickle +from collections.abc import Callable +from functools import lru_cache, partial, wraps +from typing import Any +from typing_extensions import ParamSpec, TypeVar + +from . import exceptions + + +# Type specification for function parameters +P = ParamSpec("P") +# Type variable for function return values +R = TypeVar("R") + + +def _lru_cache(fn: Callable[P, R]) -> Callable[P, R]: + """LRU cache decorator with TypeError fallback. + + Provides LRU caching with a fallback mechanism that calls the original + function if caching fails due to unhashable arguments. Uses a cache + size of 64 with typed comparison. + + Args: + fn: The function to be cached. + + Returns: + A wrapper function that attempts caching with fallback to original function. + """ + cached_fn = lru_cache(maxsize=64, typed=True)(fn) + + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[type-var] + try: + return cached_fn(*args, **kwargs) # type: ignore[arg-type] + except TypeError: + return fn(*args, **kwargs) + + return wrapper + + +@_lru_cache +def _try_pickle(to_pickle: Any, raise_if_failed: type = exceptions.CacheError) -> bytes: + """Attempt to pickle an object with error handling. + + Tries to serialize an object using pickle.dumps with appropriate error + handling and custom exception raising. + + Args: + to_pickle: The object to be pickled. + raise_if_failed: Exception class to raise if pickling fails. + + Returns: + The pickled bytes representation of the object. + + Raises: + The exception class specified in raise_if_failed if pickling fails. + """ + try: + pickled: bytes = pickle.dumps(to_pickle) + except (pickle.PicklingError, AttributeError) as err: + raise raise_if_failed(to_pickle) from err + return pickled + + +# Specialized pickle function for cache keys with KeyPicklingError handling. 
+_try_pickle_key: Callable[[Any], bytes] = partial( + _try_pickle, raise_if_failed=exceptions.KeyPicklingError +) +# Specialized pickle function for cache values with ValuePicklingError handling. +_try_pickle_value: Callable[[Any], bytes] = partial( + _try_pickle, raise_if_failed=exceptions.ValuePicklingError +) + + +@_lru_cache +def _try_unpickle(pickled: bytes, raise_if_failed: type = exceptions.CacheError) -> Any: + """Attempt to unpickle bytes with error handling. + + Tries to deserialize bytes using pickle.loads with appropriate error + handling and custom exception raising. + + Args: + pickled: The bytes to be unpickled. + raise_if_failed: Exception class to raise if unpickling fails. + + Returns: + The unpickled object. + + Raises: + The exception class specified in raise_if_failed if unpickling fails. + """ + try: + unpickled: Any = pickle.loads(pickled) + except pickle.UnpicklingError as err: + raise raise_if_failed(pickled) from err + return unpickled + + +# Specialized unpickle function for cache values with ValueUnPicklingError handling. +_try_unpickle_value: Callable[[bytes], Any] = partial( + _try_unpickle, raise_if_failed=exceptions.ValueUnPicklingError +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3d731525ea8d1bebac20a4b2e9ac732469cdd4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__init__.py @@ -0,0 +1,6 @@ +# NOTE: add new template heuristics here, so they get imported and registered +# TODO: write a simple glob if there are many heuristics to auto import them in the right order +from . 
import aten, base, contiguous_mm, decompose_k, registry, triton + +# expose the entry function +from .registry import get_template_heuristic diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/aten.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/aten.py new file mode 100644 index 0000000000000000000000000000000000000000..103668aa056faae96c6e65ef9a8d912ef6543c6e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/aten.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from torch._inductor import config as inductor_config + +from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype +from ..kernel.mm import ( + aten__fp8_mm, + aten__int_mm, + aten_addmm, + aten_bias_addmm, + aten_mm, + aten_mm_dtype, +) +from ..kernel.mm_plus_mm import aten_mm_plus_mm +from .base import TemplateConfigHeuristics +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Generator + + from ..kernel_inputs import KernelInputs + + +# These are all labeled as device type None to indicate that they +# are valid for all device types +@register_template_heuristic(aten_mm.uid, None) +@register_template_heuristic(aten_mm_dtype.uid, "cuda") +@register_template_heuristic(aten__fp8_mm.uid, None) +@register_template_heuristic(aten__int_mm.uid, None) +@register_template_heuristic(aten_bmm.uid, None) +@register_template_heuristic(aten_mm_plus_mm.uid, None) +# bmm dtype is only valid on cuda +@register_template_heuristic(aten_bmm_dtype.uid, "cuda") +class ATenConfigHeuristics(TemplateConfigHeuristics): + """ + Pseudo heuristic to make ATen choices go through the same flow as other templates + + This is a single choice without kwargs + + If you want to use this with an ATen 
choice that has kwargs, just subclass + """ + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + yield dict() + + +# None here indicates that this is valid for all device types on that op +# Note (None, op) takes precedence over (device_type, None) +@register_template_heuristic(aten_addmm.uid, None, op_name="addmm") +@register_template_heuristic(aten_baddbmm.uid, None, op_name="baddbmm") +class ATenAddMMConfigHeuristics(ATenConfigHeuristics): + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + alpha = kernel_inputs.get_scalar("alpha") + beta = kernel_inputs.get_scalar("beta") + return { + **kwargs, + "alpha": alpha, + "beta": beta, + } + + +@register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm") +class ATenBiasAddMMConfigHeuristics( + ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics +): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + nodes = kernel_inputs.nodes() + # for addmm, bias is the first input + bias = nodes[0] + if bias.get_stride()[0] == 0 and inductor_config.triton.autotune_cublasLt: + yield dict() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0343270f3a1111de9963f2dfb4781b7aabd1d855 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/base.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from .params import DictKernelTemplateParams, KernelTemplateParams + + +if TYPE_CHECKING: + from 
collections.abc import Generator + + from ..kernel_inputs import KernelInputs + + +class TemplateConfigHeuristics: + """Base class for generating sets of configs for an associated template.""" + + def should_run(self, inputs: KernelInputs) -> bool: + """ + Hook to check whether the configs are right to run at all e.g. you can check + max-autotune specific to your heuristic here or other things + If this returns False, get_template_configs will yield no configs + + Args: + inputs: KernelInputs + """ + return True + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[KernelTemplateParams, None, None]: + """ + Get template configs for the given inputs. + + Prefer to override the _get_template_configs_impl method + to leverage things like should_run + """ + if not self.should_run(kernel_inputs): + return + + # Generate configs and wrap each dict in DictKernelTemplateParams + for config_dict in self._get_template_configs_impl(kernel_inputs, op_name): + # NOTE: extra_kwargs are not merged here; this only wraps the config dict + yield DictKernelTemplateParams(config_dict) + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get template configs for the given inputs. + This is the main entry point for template-specific logic. + """ + # base implementation yields no entries + yield from [] + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + """ + Get extra kwargs for the given inputs/op for the template. + + Use this to return kwargs that are needed for the template, but + do not change depending on the config/choice, but are rather + always the same, for all configs + """ + return {} + + def adjust_kernel_inputs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> KernelInputs: + """ + Adjust kernel inputs for the given inputs/op for the template. + + override this to adjust the kernel inputs e.g. 
(un)squeezing + """ + return kernel_inputs diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/contiguous_mm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/contiguous_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b65eba9c76cfbaa23d67cca2a6fb0d51d317dc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/contiguous_mm.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import torch + +from ..ir import get_free_symbols +from ..kernel.mm import ( + addmm_contiguous_subgraph_template, + mm_contiguous_subgraph_template, +) +from ..kernel_inputs import KernelInputs, MMKernelInputs +from ..utils import use_contiguous +from .base import TemplateConfigHeuristics +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Generator + + +@register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm") +@register_template_heuristic( + addmm_contiguous_subgraph_template.uid, None, op_name="addmm" +) +class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics): + """empty heuristics to skip contiguous mm wherever the "cuda" heuristic below is not registered (it registers only when torch.version.hip is set)""" + + +@register_template_heuristic( + mm_contiguous_subgraph_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="mm", +) +@register_template_heuristic( + addmm_contiguous_subgraph_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="addmm", +) +class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Yield one empty config if the contiguous mm decomposition applies, else nothing. 
+ """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + # Check for unbacked symbols - if found, yield nothing + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + *kernel_inputs.shapes_symbolic(), + *kernel_inputs.strides_symbolic(), + ) + ) + if unbacked_symbols: + return + mat2 = kernel_inputs.mat1mat2()[1] + if mat2.get_layout().is_contiguous(): + # no need for contiguous decomposition + return + m, n, k = kernel_inputs.mnk_symbolic() + if not use_contiguous(m, n, k): + return + yield {} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/cutedsl.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 0000000000000000000000000000000000000000..db337b9d8a271d25f28c55a23aaa2dc91e56b0bf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
+ For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
+ """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. 
+ """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/decompose_k.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/decompose_k.py new file mode 100644 index 0000000000000000000000000000000000000000..7954396a10861b39748ad73075b343286551a102 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/decompose_k.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import sympy + +import torch + +from ..ir import get_free_symbols +from ..kernel.mm import decompose_k_subgraph_template +from ..kernel_inputs import KernelInputs, MMKernelInputs +from ..utils import get_k_splits +from ..virtualized import V +from .base import TemplateConfigHeuristics +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Generator + + +@register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm") +class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics): + """empty heuristics to skip decompose k on anything not cuda""" + + +# on CUDA, we don't support hip for decompose_k yet +@register_template_heuristic( + decompose_k_subgraph_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="mm", +) +# TODO(coconutruben): enable decompose k on AMD by removing the register bool +# and benchmarking it for performance and stability +# TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia) +# by either adding specific 
register_template_heuristic tags, or setting the +# device to None (enabled on all devices) +class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get all the valid k_splits for the given m, n, k. + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + + # Check for unbacked symbols - if found, yield nothing + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + *kernel_inputs.shapes_symbolic(), + *kernel_inputs.strides_symbolic(), + ) + ) + if unbacked_symbols: + return + + m, n, k = kernel_inputs.mnk_symbolic() + k_splits = get_k_splits(m, n, k) + for k_split in k_splits: + if not V.graph.sizevars.statically_known_true( + sympy.Eq(sympy.Mod(k, k_split), 0) + ): + continue + yield {"k_split": k_split} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/gemm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d56f4c481ccd0601d75b8867a48634c7001abc3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/gemm.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .. 
import config as inductor_config +from .base import TemplateConfigHeuristics + + +if TYPE_CHECKING: + from ..kernel_inputs import KernelInputs + + +class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics): + def should_run(self, inputs: KernelInputs) -> bool: + """ + simple base override for GEMM family templates that run only in max-autotune + """ + return inductor_config.max_autotune or inductor_config.max_autotune_gemm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/params.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/params.py new file mode 100644 index 0000000000000000000000000000000000000000..92b130217e3d19507b51e7bd384072548c67abb4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/params.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class KernelTemplateParams(ABC): + """Abstract base class for kernel template parameters.""" + + @abstractmethod + def to_kwargs(self) -> dict[str, Any]: + """Convert params to kwargs dict for template.choice_or_none()""" + + @abstractmethod + def to_serializeable_dict(self) -> dict[str, Any]: + """Convert params to serializable dict for storage/caching""" + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> KernelTemplateParams: + """Create params instance from dict""" + + +class DictKernelTemplateParams(KernelTemplateParams): + """Simple implementation that wraps a kwargs dict""" + + # NOTE: this is a compatibility layer, until every template + # has time to define their own params class, with meaningful + # defaults etc. 
+ + def __init__(self, kwargs: dict[str, Any]): + self.kwargs = kwargs + + def to_kwargs(self) -> dict[str, Any]: + return self.kwargs.copy() + + def to_serializeable_dict(self) -> dict[str, Any]: + return self.kwargs.copy() + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DictKernelTemplateParams: + return cls(data) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..247c78fd557580e33474c8550e645c372db49903 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/registry.py @@ -0,0 +1,175 @@ +""" +Template heuristic registry system for PyTorch Inductor. + +This module provides a centralized registration system for template heuristics, +allowing automatic registration based on device type and conditional registration +for CUDA vs ROCm based on torch.version.hip. +""" + +from __future__ import annotations + +import contextlib +import logging +from typing import Any, Optional, TYPE_CHECKING, Union + +from .base import TemplateConfigHeuristics + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +# Module-wide registry for template heuristics +_TEMPLATE_HEURISTIC_REGISTRY: dict[ + tuple[Union[str, None], ...], type[TemplateConfigHeuristics] +] = {} + +# Manual cache for successful lookups only (fallback instances are not cached) +_HEURISTIC_CACHE: dict[tuple[str, str, str], TemplateConfigHeuristics] = {} + +log = logging.getLogger(__name__) + + +def register_template_heuristic( + template_name: str, + device_type: Union[str, None], + register: bool = True, + op_name: Optional[str] = None, +) -> Any: + """ + Decorator to register template heuristic classes. 
+ + Args: + template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") + device_type: Device type ("cuda", "cpu", "xpu") + Set this to None to indicate that the heuristic is applicable to all device types. + register: Whether to register this heuristic. Caller should pass the condition directly. + op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional + and is only used when a template uses different heuristics for different ops + + Returns: + Decorator function that registers the class if conditions are met. + + Example: + @register_template_heuristic("mm", "cuda", register=torch.version.hip is None) + class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + pass + """ + + def decorator( + cls: type[TemplateConfigHeuristics], + ) -> type[TemplateConfigHeuristics]: + if register: + key: tuple[Union[str, None], ...] = (template_name, device_type, op_name) + _TEMPLATE_HEURISTIC_REGISTRY[key] = cls + log.info( + f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'" # noqa: G004 + ) + return cls + + return decorator + + +def get_template_heuristic( + template_name: str, device_type: str, op_name: str +) -> TemplateConfigHeuristics: + """ + Retrieve a template heuristic instance for the given template and device type. + + Args: + template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") + device_type: Device type ("cuda", "cpu", "xpu") + op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm") + + Returns: + Template heuristic instance. If no specific heuristic is found, + returns a fallback TemplateConfigHeuristics() instance (uncached). 
+ """ + # Check cache first + cache_key = (template_name, device_type, op_name) + if cache_key in _HEURISTIC_CACHE: + return _HEURISTIC_CACHE[cache_key] + + keys = [ + # everything is specified + (template_name, device_type, op_name), + # heuristic is valid across all devices + (template_name, None, op_name), + # heuristic is valid across all ops for that device + (template_name, device_type, None), + # heuristic is always valid for that template + (template_name, None, None), + ] + + # Look up in registry + heuristic_class = None + for key in keys: + if key in _TEMPLATE_HEURISTIC_REGISTRY: + heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key] + break + + if heuristic_class is None: + # Log error and return fallback instance (uncached) + log.error( + "No template heuristic found - template_name=%s, device_type=%s, op_name=%s. " + "Available combinations: %s. Using fallback TemplateConfigHeuristics instance.", + template_name, + device_type, + op_name, + list(_TEMPLATE_HEURISTIC_REGISTRY.keys()), + ) + return TemplateConfigHeuristics() + + # Cache successful lookup and return + instance = heuristic_class() + _HEURISTIC_CACHE[cache_key] = instance + return instance + + +def clear_registry() -> None: + """ + Clear all registered template heuristics. + + This is primarily useful for testing purposes to ensure a clean state. + """ + _TEMPLATE_HEURISTIC_REGISTRY.clear() + _HEURISTIC_CACHE.clear() + + +@contextlib.contextmanager +def override_template_heuristics( + device_type: str, + template_op_pairs: list[tuple[str, str]], +) -> Iterator[None]: + """ + Context manager to temporarily override template heuristics with an empty heuristic. + + This is useful for testing purposes, where we want to ensure a specific template/op pair + is not used + + Args: + device_type: Device type ("cuda", "cpu", "xpu") + template_op_pairs: List of (template_name, op_name) pairs to override. 
+ """ + # Save original entries to restore later; keys use the registry layout (template_name, device_type, op_name) + original_entries = {} + new_keys = [] + _HEURISTIC_CACHE.clear() + try: + for template_name, op_name in template_op_pairs: + assert op_name is not None + key = (template_name, device_type, op_name) + if key in _TEMPLATE_HEURISTIC_REGISTRY: + original_entries[key] = _TEMPLATE_HEURISTIC_REGISTRY[key] + # TemplateConfigHeuristics base class returns no entries + # so we use it for overriding + _TEMPLATE_HEURISTIC_REGISTRY[key] = TemplateConfigHeuristics + new_keys.append(key) + yield + finally: + # Restore original entries or remove if they didn't exist before + for key in new_keys: + _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None) + if key in original_entries: + _TEMPLATE_HEURISTIC_REGISTRY[key] = original_entries[key] + _HEURISTIC_CACHE.clear() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..21deda557346b8adda8668699120854e705e524e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton.py @@ -0,0 +1,2649 @@ +from __future__ import annotations + +import dataclasses +import itertools +import math +import os +from functools import partial +from threading import Lock +from typing import Any, Optional, TYPE_CHECKING + +import sympy + +import torch +from torch._inductor.template_heuristics.triton_addmm import AddMMConfigMixin +from torch.utils._ordered_set import OrderedSet +from torch.utils._triton import has_triton_stable_tma_api + +from .. 
import config, config as inductor_config +from ..kernel.bmm import bmm_template +from ..kernel.mm import ( + blackwell_ws_persistent_device_tma_mm_template, + get_scaling_options, + get_tile_size, + mm_template, + persistent_tma_mm_template, + scaled_mm_device_tma_epilogue_scaling_template, + scaled_mm_device_tma_main_loop_scaling_template, +) +from ..kernel.mm_plus_mm import mm_plus_mm_template +from ..kernel_inputs import KernelInputs, MMKernelInputs +from ..utils import ( + get_backend_num_stages, + get_num_sms, + get_tma_workspace_arg, + TMA_DESCRIPTOR_SIZE, + using_b200, +) +from ..virtualized import V +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + + from triton import Config as TritonConfig + + +# Gemm Configs +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + hint_override: Optional[int] = dataclasses.field(kw_only=True, default=None) + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = dataclasses.field(kw_only=True, default=8) + + +ConvConfig = BaseConfig + + +# FlexAttention Configs +@dataclasses.dataclass +class FlexConfig: + """ + Base Config class for flex attention + - FlexAttn forward and backward will use this. For flex decoding, + please use FlexDecodingConfig. + + NOTE: + For flex_attn bwd block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2 + + """ + + block_m: int + block_n: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexBwDConfig: + """ + Base Config class for flex attention backward + - FlexAttn backward will use this. 
+ + Note: flex bwd configs + + Kernel Constraints: + * BLOCK_N1 % BLOCK_M1 == 0 + * BLOCK_M2 % BLOCK_N2 == 0 + + Pattern 1 - Symmetric Pairing (M, N, N, M): + - Used in autotune configs + - block_m1=M, block_n1=N, block_m2=N, block_n2=M + - Only requires checking BLOCK_N % BLOCK_M == 0 + - Second constraint (BLOCK_M2 % BLOCK_N2) automatically satisfied + + Pattern 2 - Independent Parameters (M1, N1, M2, N2): + - Used in exhaustive search for maximum flexibility + - All four parameters can be set independently + - Requires checking both constraints + + """ + + block_m1: int + block_n1: int + block_m2: int + block_n2: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexDecodeConfig: + """ + Config class for flex decoding + """ + + block_n: int + num_stages: int + num_warps: int + + +# ROCm classes +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexConfig(FlexConfig): + """ + ROCm subclass for FlexAttn, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexBwDConfig(FlexBwDConfig): + """ + ROCm subclass for FlexAttn backward, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexDecodeConfig(FlexDecodeConfig): + """ + ROCm subclass for FlexDecode, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +class 
BaseHeuristicSingleton(type): + """ + Thread-safe implementation of single to be used in the config heuristic subclasses + to ensure heavy __init__ calls are not repeatedly run + """ + + _instances: dict[type[Any], Any] = {} + _lock: Lock = Lock() + + def __call__( + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any + ) -> BaseConfigHeuristic: + with cls._lock: + if cls not in cls._instances: + instance = super().__call__() + cls._instances[cls] = instance + return cls._instances[cls] + + +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): + """ + Base class for mm_configs, device specific triton kernels config inherit from here + """ + + def __init__(self) -> None: + # Whether the heuristic is used for int8. Use this when the heuristic is int8 exclusive + # but prefer the preprocess_mm_configs argument when it's used for both + self.has_int8_tensor: bool = False + # Whether to scale configs at all + # TODO(coconutruben): remove this once mm_plus_mm and tests support scaling + self.should_scale_configs: bool = True + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform. 
The configs are as follows: + # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(128, 128, 128, 4, 8), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig( + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m=group_m + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + for group_m in [8] + ] + + # these are only used in tuned_mm when AutoHeuristic is enabled + # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned + # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 + # which saves compilation time (since less configs are autotuned) and potentially increase performance + # because the learned heuristic might predict a config that is not part mm_configs + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + 
self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), + ] + + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), + ] + + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(256, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.blackwell_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 4, 8), + GemmConfig(256, 128, 64, 3, 8), + GemmConfig(128, 256, 128, 2, 8), + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(256, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + ] + + self.blackwell_persistent_addmm_configs: list[BaseConfig] = [ + GemmConfig(256, 128, 64, 2, 4), + ] + + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + 
GemmConfig(16, 256, 32, 2, 4), + GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 
32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), + GemmConfig(64, 16, 256, 5, 4), + GemmConfig(64, 32, 256, 5, 4), + GemmConfig(64, 128, 128, 2, 4), + GemmConfig(64, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 2, 4), + GemmConfig(128, 256, 128, 4, 8), + GemmConfig(256, 128, 128, 2, 4), + GemmConfig(256, 128, 128, 2, 8), + ] + + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(64, 32, 256, 5, 4), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(64, 128, 256, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + ] + + # TODO: Unify with other gemm patterns, mm_plus_mm currently follows + # slightly different pattern than rest + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), + ] + + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + 
ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), + ] + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(128, 64, 3, 4), + FlexConfig(128, 128, 3, 4), + FlexConfig(128, 128, 2, 8), + FlexConfig(128, 128, 1, 8), + FlexConfig(64, 128, 3, 4), + FlexConfig(64, 64, 3, 4), + ] + + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK_M, BLOCK_N, BLOCK_N, BLOCK_M, s, w) + for BLOCK_M in [32, 64] + for BLOCK_N in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK_M >= 128 or BLOCK_N >= 128 else [4]) + if BLOCK_N % BLOCK_M == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(64, 3, 2), + FlexDecodeConfig(32, 3, 2), + FlexDecodeConfig(128, 3, 2), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2, num_stages, num_warps) + for BLOCK_M1 in [16, 32, 64, 128] + for BLOCK_N1 in [16, 32, 64, 128] + for BLOCK_M2 in [16, 32, 64, 128] + for BLOCK_N2 in [16, 32, 64, 128] + for num_stages in [1, 3, 4] + for num_warps in [2, 4, 8] + if BLOCK_N1 % BLOCK_M1 == 0 + and BLOCK_M2 % BLOCK_N2 == 0 # kernel static assertions + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(block_n, num_stages, num_warps) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. 
+ """ + used: OrderedSet[tuple[Optional[int], ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[Optional[int], ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.hint_override, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "hint_override": conf.hint_override, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(conf.num_stages, num_warps, **kwargs) + + def _scale_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + scale: float, + has_int8_tensor: bool, + exclude: Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool], + hint_override: Optional[int] = None, + ) -> list[BaseConfig]: + """ + Scales and filters matrix multiplication configs based on input size. 
+ """ + if not self.should_scale_configs: + return configs + from ..runtime.runtime_utils import next_power_of_2 + + min_block_size = 16 + min_block_size_k = 32 if (has_int8_tensor or self.has_int8_tensor) else 16 + + scaled_configs = [] + for hint_override in [None] + config.multi_kernel_hints: + m_hint = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ) + ), + min_block_size, + ) + n_hint = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ) + ), + min_block_size, + ) + k_hint = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ) + ), + min_block_size_k, + ) + + for c in configs: + scaled_config = dataclasses.replace( + c, + block_m=max(min(int(c.block_m * scale), m_hint), min_block_size), + block_n=max(min(int(c.block_n * scale), n_hint), min_block_size), + block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k), + hint_override=hint_override, + ) + + if not exclude( + scaled_config.block_m, scaled_config.block_n, scaled_config.block_k + ): + scaled_configs.append(scaled_config) + + return scaled_configs + + def _get_exceeding_shared_memory_checker( + self, + ) -> Optional[Callable[[BaseConfig, int], bool]]: + """ + Returns a function that checks whether a given configuration exceeds the available shared memory for the device. + If the device does not report available shared memory, returns None. 
+ """ + + try: + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + if hasattr(props, "shared_memory_per_block_optin"): # for NVidia GPUs + sm_available = int(props.shared_memory_per_block_optin) + elif hasattr(props, "shared_memory_per_block"): # for ROCm + sm_available = int(props.shared_memory_per_block) + else: + return None + + except Exception: + # If CUDA is not available or properties cannot be queried, return None + return None + + # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation. + def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool: + shared_mem_accum = dtype_size * ( + gemm_config.block_m * gemm_config.block_k + + gemm_config.block_n * gemm_config.block_k + ) + return shared_mem_accum * gemm_config.num_stages > sm_available + + return exceeds + + def _prune_exceeding_max_shared_mem_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + if dtype_size <= 0: + return configs + + is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker() + if is_exceeding_shared_memory is None: + return configs + + return [c for c in configs if not is_exceeding_shared_memory(c, dtype_size)] + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker() + + pruned_configs = [] + for gemm_config in configs: + # Will use more shared memory than available + if is_exceeding_shared_memory and is_exceeding_shared_memory( + gemm_config, dtype_size + ): + continue + + NUM_REG = 255 + acc_regs = math.ceil( + gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32) + ) + # Lower bound for register spillage, if exceeds the kernel will certainly spill + if acc_regs > NUM_REG: + continue + + pruned_configs.append(gemm_config) + + return pruned_configs + + def _filter_configs(self, configs: 
list[BaseConfig]) -> list[BaseConfig]: + """ + Filter configs based on specific requirements. + Subclasses can override this to implement custom filtering logic. + """ + return configs + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", # For preprocessing overrides e.g. on CPU + ) -> Generator[TritonConfig, None, None]: + configs = self._filter_configs(configs) + scaled_configs = self._scale_mm_configs( + m, n, k, configs, scale, has_int8_tensor, exclude + ) + + # Filter out configs that require more shared memory than is available. + if config.max_autotune_prune_choices_based_on_shared_mem: + scaled_configs = self._prune_exceeding_max_shared_mem_configs( + scaled_configs, dtype_size + ) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + assert dtype_size > 0, "dtype_size must be provided for exhaustive search" + scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size) + return self._finalize_mm_configs(scaled_configs) + + def triton_config( + self, num_stages: int, num_warps: int, **kwargs: Any + ) -> TritonConfig: + from triton import Config as TritonConfig # type: ignore[attr-defined] + + return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps) + + def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.mm_configs) + + def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial( + self.preprocess_mm_configs, configs=self.conv_configs, op_name="conv" + ) + + # Flex attn helpers + def 
get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class CPUConfigHeuristic(BaseConfigHeuristic): + """ + CPU-specific config heuristic with CPU-specific optimizations. 
+ """ + + def _get_cpu_exclude_function( + self, method: str = "bmm" + ) -> Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool]: + """ + Get CPU-specific exclude function based on method type. + Returns a function that can be used as exclude condition. + Moved from mm_common._is_large_block_for_cpu and refactored to return a function. + """ + if method in ("conv"): + + def exclude_conv( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 256 or n > 256 or k > 256: + return True + return m * n * k > 2**17 + + return exclude_conv + elif method in ("mm", "addmm", "int_mm"): + + def exclude_mm( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + return m * n > 2**13 + + return exclude_mm + else: # Default to bmm implementation for unknown methods + + def exclude_bmm( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + if m > 128 or n > 128 or k > 128: + return True + return m * n > 2**12 + + return exclude_bmm + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", # For preprocessing overrides e.g. on CPU + ) -> Generator[TritonConfig, None, None]: + """ + CPU-specific preprocessing that applies CPU-specific scaling (0.5) and exclusion logic. 
+ """ + # Get CPU-specific exclude function based on operation type + cpu_exclude_fn = self._get_cpu_exclude_function(op_name) + + # Apply CPU-specific scaling (0.5) and exclusion logic + return super().preprocess_mm_configs( + m, + n, + k, + configs=configs, + has_int8_tensor=has_int8_tensor, + scale=0.5, + exclude=cpu_exclude_fn, + dtype_size=dtype_size, + op_name=op_name, + ) + + +class CUDAConfigHeuristic(BaseConfigHeuristic): + """ + Child class for CUDA device specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + self.sm_120_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 2, 4), + (torch.float32, 128): FlexConfig(128, 32, 2, 4), + (torch.float32, 256): FlexConfig(64, 16, 2, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 2, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 2, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 2, 4), + (torch.float16, 64): FlexConfig(128, 64, 2, 4), + (torch.float16, 128): FlexConfig(128, 64, 2, 8), + (torch.float16, 256): FlexConfig(32, 64, 2, 4), + } + + self.sm_100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 192): FlexConfig(32, 64, 2, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 192): FlexConfig(128, 128, 1, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 192): FlexConfig(128, 128, 1, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.h100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): 
FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.a100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(128, 32, 3, 4), + (torch.float32, 256): FlexConfig(64, 16, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 3, 4), + (torch.float16, 64): FlexConfig(128, 64, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(32, 64, 3, 4), + } + + # Overwriting the configs omitting BLOCK_N of size 128 that cause ULFs + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK_M, BLOCK_N, BLOCK_N, BLOCK_M, s, 4) + for BLOCK_M in [32, 64] + for BLOCK_N in [32, 64] + for s in [1, 3, 4, 5] # num_stages + if BLOCK_N % BLOCK_M == 0 + ] + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(64, 64, 3, 4) + if capability >= (12, 0): + default_config = self.sm_120_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (10, 0): + default_config = self.sm_100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability == (9, 0): + default_config = self.h100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif 
capability >= (8, 0): + default_config = self.a100_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + major, minor = capability + if dtype == torch.float32: + capability_class = "float32" + elif major == 12: + capability_class = "sm12x" + elif major >= 10: + capability_class = "sm10x" + elif capability == (9, 0): + capability_class = "sm90" + elif major >= 8: + capability_class = "sm8x" + else: + capability_class = "baseline" + + # fmt: off + config_map = { + "float32": lambda h: FlexBwDConfig(16, 16, 16, 16, 1, 4), + "baseline": lambda h: FlexBwDConfig(16, 16, 16, 16, 1, 4), + "sm90": lambda h: ( + FlexBwDConfig(64, 64, 64, 64, 3, 4) if h < 64 else + FlexBwDConfig(64, 128, 128, 64, 3, 8) if h <= 128 else + FlexBwDConfig(64, 64, 64, 64, 2, 4) + ), + "sm10x": lambda h: ( + FlexBwDConfig(64, 128, 128, 64, 3, 4) if h <= 128 else + FlexBwDConfig(64, 64, 64, 64, 1, 8) if h <= 192 else + FlexBwDConfig(64, 64, 64, 64, 1, 4) + ), + "sm8x": lambda h: ( + FlexBwDConfig(32, 128, 128, 32, 3, 4) + if h < 64 + else FlexBwDConfig( + 64, 64, 64, 64, 3 if minor == 6 and h == 128 else 2, 4 + ) + ), + "sm12x": lambda h: ( + FlexBwDConfig(32, 128, 128, 32, 3, 4) + if h < 64 + else FlexBwDConfig( + 64, 64, 64, 64, 3 if minor == 6 and h == 128 else 2, 4 + ) + ), + } + # fmt: on + + if head_dim <= 256: + default_config = 
config_map[capability_class](head_dim) + else: + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + capability = torch.cuda.get_device_capability() + + default_config = FlexDecodeConfig(64, 1, 2) + + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + if capability in [(9, 0), (10, 0), (10, 3)]: # sm_90, sm_100, sm_103 + if head_dim > 128 and dtype == torch.float32: + default_config = FlexDecodeConfig(64, 1, 2) + else: + default_config = FlexDecodeConfig(64, 3, 2) + else: + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class ROCmConfigHeuristic(BaseConfigHeuristic): + """ + Child class for ROCm specific gemm/flex attention/conv/ configs. 
+ """ + + def __init__(self) -> None: + super().__init__() + + self.default_num_stages = get_backend_num_stages() + + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, 
self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m=group_m, + matrix_instr_nonkdim=matrix_instr_nonkdim, + waves_per_eu=waves_per_eu, + kpack=kpack, + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, self.default_num_stages] + for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] + ] + + self.default_flex_config = { + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8), + (torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4), + } + + 
self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + for BLOCK1 in [16, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for w in [4, 8] + ] + + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + ROCmFlexBwDConfig(BLOCK1, BLOCK2, BLOCK2, BLOCK1, 1, w, mfma) + for BLOCK1 in [16, 32, 64] + for BLOCK2 in [32, 64, 128] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for mfma in [0, 16] + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(32, 1, 4), + ROCmFlexDecodeConfig(64, 1, 4), + ROCmFlexDecodeConfig(128, 1, 4), + ROCmFlexDecodeConfig(32, 1, 8), + ROCmFlexDecodeConfig(64, 1, 8), + ROCmFlexDecodeConfig(128, 1, 8), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + ROCmFlexBwDConfig( + BLOCK_M1, + BLOCK_N1, + BLOCK_M2, + BLOCK_N2, + num_stages, + num_warps, + mfma, + wpeu, + ) + for BLOCK_M1 in [16, 32, 64, 128] + for BLOCK_N1 in [16, 32, 64, 128] + for BLOCK_M2 in [16, 32, 64, 128] + for BLOCK_N2 in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + if BLOCK_N1 % BLOCK_M1 == 0 + and BLOCK_M2 % BLOCK_N2 == 0 # kernel static assertions + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + def 
_prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + # these cause AMD compile to crash + pruned_configs = [ + c + for c in configs + if not ( + getattr(c, "matrix_instr_nonkdim", 0) == 2 + and getattr(c, "kpack", 0) == 2 + ) + ] + return pruned_configs + + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + ROCm specific filtering + """ + for c in configs: + c.num_stages = self.default_num_stages + return super()._filter_configs(configs) + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] 
= ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + # AMD GPU crashes if group_m = 0 + if group_m is not None and group_m <= 0: + group_m = 8 + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(128, 64, 1, 8) + default_config = self.default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = ROCmFlexConfig(32, 16, 1, 4) + else: + default_config = ROCmFlexConfig(64, 32, 1, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + 
return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = ROCmFlexBwDConfig(16, 16, 16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = ROCmFlexBwDConfig(64, 64, 64, 64, 1, 4) + elif head_dim == 128: + default_config = ROCmFlexBwDConfig(64, 128, 128, 64, 1, 8) + else: + default_config = ROCmFlexBwDConfig(64, 64, 64, 64, 1, 4) + else: + default_config = ROCmFlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = ROCmFlexDecodeConfig(64, 1, 4) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class XPUConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for Intel GPU specific overrides. 
+ """ + + def __init__(self) -> None: + super().__init__() + self.xpu_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 1, 16), + (torch.float32, 128): FlexConfig(128, 32, 1, 16), + (torch.float32, 256): FlexConfig(64, 16, 1, 8), + (torch.bfloat16, 64): FlexConfig(128, 64, 1, 16), + (torch.bfloat16, 128): FlexConfig(128, 64, 1, 16), + (torch.bfloat16, 256): FlexConfig(32, 64, 1, 4), + (torch.float16, 64): FlexConfig(128, 64, 1, 16), + (torch.float16, 128): FlexConfig(128, 64, 1, 16), + (torch.float16, 256): FlexConfig(32, 64, 1, 4), + } + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(32, 16, 2, 4), + FlexConfig(128, 64, 2, 16), + FlexConfig(128, 64, 2, 8), + FlexConfig(128, 32, 2, 16), + FlexConfig(128, 32, 2, 8), + ] + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [] + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [] + + if not bool(os.getenv("CI")): + self.flex_attn_bwd_autotune_configs += [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK1, BLOCK2, BLOCK2, BLOCK1, s, w) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + if BLOCK2 % BLOCK1 == 0 + ] + self.flex_decode_autotune_configs += [ + FlexDecodeConfig(32, 1, 2), + FlexDecodeConfig(32, 1, 1), + FlexDecodeConfig(32, 2, 2), + FlexDecodeConfig(32, 2, 1), + FlexDecodeConfig(64, 1, 2), + FlexDecodeConfig(64, 1, 1), + FlexDecodeConfig(64, 2, 2), + FlexDecodeConfig(64, 2, 1), + ] + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 1, 8) + else: + default_config = 
FlexConfig(128, 64, 1, 16) + default_config = self.xpu_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 1, 4) + else: + default_config = FlexConfig(64, 32, 1, 8) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = FlexBwDConfig(64, 64, 64, 64, 1, 8) + elif head_dim == 128: + default_config = FlexBwDConfig(64, 128, 64, 128, 1, 8) + else: + default_config = FlexBwDConfig(64, 64, 64, 64, 1, 8) + else: # modest hardware or extremely large head_dim + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + return configs + + +class 
MTIAConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for MTIA specific overrides. + """ + + +# Template-specific mixin classes +class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics): + """ + Mixin class that converts config lists to template kwargs. + This handles the logic that was previously in choices.get_mm_configs. + + This mixin expects to be used with BaseConfigHeuristic or its subclasses. + """ + + # Type annotations to ensure the mixin works with BaseConfigHeuristic + get_mm_configs: Callable[[], partial[Generator[TritonConfig, None, None]]] + get_exhaustive_mm_configs: Callable[ + [], partial[Generator[TritonConfig, None, None]] + ] + _filter_configs: Callable[[list[BaseConfig]], list[BaseConfig]] + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + assert isinstance(kernel_inputs, MMKernelInputs) + m, n, k = kernel_inputs.mnk_symbolic() + # Calculate allow_tf32 + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0) + ) + + return { + "ALLOW_TF32": allow_tf32, + } + + def _valid(self, kernel_inputs: KernelInputs) -> bool: + return True + + def _get_config_generator( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + """ + Get the appropriate config generator based on search space. + Can be overridden by subclasses for template-specific behavior. + """ + # Handle exhaustive search case + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return self.get_exhaustive_mm_configs() + else: + return self.get_mm_configs() + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Convert config lists to template kwargs. + This replaces the logic from choices.get_mm_configs and inlines mm_options. 
+ """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + input_nodes = kernel_inputs.nodes() + if len(input_nodes) < 2: + raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}") + if not self._valid(kernel_inputs): + return + + # Extract M, N, K from kernel_inputs + m, n, k = kernel_inputs.mnk_symbolic() + + # Extract dtype and device_type from kernel_inputs + dtype = kernel_inputs.dtype() + + # Get the appropriate config generator + configs = self._get_config_generator() + + # Generate and process configs + for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name): + template_kwargs = self._convert_config_to_template_kwargs( + c, + m, + n, + k, + kernel_inputs.out_dtype(), + ) + yield template_kwargs + + def _convert_config_to_template_kwargs( + self, + triton_config: TritonConfig, + m: sympy.Integer, + n: sympy.Integer, + k: sympy.Integer, + out_dtype: torch.dtype, + ) -> dict[str, Any]: + """ + Convert triton config to template kwargs. + Moved from mm_common.mm_options. + """ + # Calculate EVEN_K symbolic + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(k, triton_config.kwargs["BLOCK_K"]) + == triton_config.kwargs["BLOCK_K"] + ) + + # Build options dict + + options_dict = dict( + EVEN_K=even_k_symbolic, + USE_FAST_ACCUM=False, # Option for _scaled_mm + ACC_TYPE=self._get_acc_type(out_dtype), + num_stages=triton_config.num_stages, + num_warps=triton_config.num_warps, + **triton_config.kwargs, + ) + + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in triton_config.kwargs: + group_m = triton_config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + + def _get_acc_type(self, dtype: torch.dtype) -> str: + """ + Get accumulator type for the given dtype. + Moved from mm_common.acc_type. 
+ """ + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +# INT8 specific mixin to filter correctly +class INT8MMTemplateConfigMixin(MMTemplateConfigMixin): + """ + Ensure that we feed in has_int8_tensor=True + """ + + def __init__(self) -> None: + super().__init__() + self.has_int8_tensor = True + + +# MMPlusMM specific mixin to avoid running _scale_mm_configs +class MMPlusMMTemplateConfigMixin(MMTemplateConfigMixin): + """ + Ensure that _should_scale_configs is False + """ + + # TODO(coconutruben): remove this once all tests work + # with proper scaling on mm_plus_mm + def __init__(self) -> None: + super().__init__() + self.should_scale_configs = False + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs" + m, n, k = kernel_inputs.mnk_symbolic() + for kwargs in super()._get_template_configs_impl(kernel_inputs, op_name): + # Apply BLOCK_K constraint specific to mm_plus_mm + # see https://github.com/triton-lang/triton/issues/1298 + # BLOCK_K = K causes llvm error + if V.graph.sizevars.statically_known_lt(kwargs.get("BLOCK_K", k), k): + yield kwargs + + +class TMAWorkspaceMixin(MMTemplateConfigMixin): + """ + Small mixin to ensure that the workspace arg is correct for TMA + and TMA specific filtering can happen. 
+ """ + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + kwargs["workspace_arg"] = get_tma_workspace_arg( + num_tma_descriptors=2, + device=kernel_inputs.device(), + ) + return kwargs + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + TMA specific filtering, as num_warps=2 not safe for TMA + """ + configs = [c for c in configs if c.num_warps != 2] + return super()._filter_configs(configs) + + +# TMA-specific mixin for TMA templates +class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin): + """ + TMA-specific mixin that uses persistent configs and adds TMA options. + This inherits from MMTemplateConfigMixin and overrides config generation. + """ + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate TMA template configs by calling super and adding TMA-specific options. 
+ """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + "TMATemplateConfigMixin requires MMKernelInputs" + ) + mat1, mat2 = kernel_inputs.mat1mat2() + tma_opts = { + "A_ROW_MAJOR": not mat1.layout.is_transposed(), + "B_ROW_MAJOR": not mat2.layout.is_transposed(), + "NUM_SMS": get_num_sms(), + "TMA_SIZE": TMA_DESCRIPTOR_SIZE, + "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), + "tma_store": config.triton.enable_template_tma_store, + "transpose_discontiguous_tensor_descriptors_override": True, + } + # Get base template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + yield {**template_kwargs, **tma_opts} + + +# TMA mixins for Blackwell templates +class BlackwellTMATemplateConfigMixin(TMATemplateConfigMixin): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate TMA template configs by calling super and adding TMA-specific options. + """ + base_ops = { + "NUM_SMS": get_num_sms(), + # TODO: Consider making this tunable. + "FLATTEN": True, + } + # Get base template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + # Some Triton versions requires num_warps >= 4 for WS + # to avoid compilation issues. Triton disables WS if num_warps < 4 + # or num_stages < 2. 
Similar issues have been seen with num_stages=1 + ws = ( + template_kwargs["num_warps"] >= 4 and template_kwargs["num_stages"] >= 2 + ) + yield { + **template_kwargs, + **base_ops, + "WARP_SPECIALIZE": ws, + "EPILOGUE_SUBTILE": config.triton.enable_epilogue_subtiling, + } + + +# Scaled MM-specific mixin for scaled MM templates +class BaseScaledMMConfigMixin(MMTemplateConfigMixin): + """ + This is a base that handles the common case for ScaledMM + + The TMA and non-TMA should build on top of this + """ + + def adjust_kernel_inputs( + self, kernel_inputs: KernelInputs, op_name: str + ) -> KernelInputs: + """ + for scaled_mm, we need to unsqueeze scale tensors, and bias + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + "Expect MMKernelInputs for scaled MM" + ) + inputs = super().adjust_kernel_inputs(kernel_inputs, op_name) + nodes = inputs.nodes() + mat_a, mat_b, scale_a, scale_b, *bias = nodes + bias = bias[0] if bias else None + # Prepare triton input nodes and create kernel_inputs at the top + from ..lowering import lowerings as L + + aten = torch.ops.aten + if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + bias = L[aten.unsqueeze](bias, 0) + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + nodes = [mat_a, mat_b, scale_a, scale_b] + if bias: + nodes.append(bias) + return MMKernelInputs( + nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx + ) + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate scaled MM template configs with scaled MM-specific options. + Handles the remaining logic from mm_common, including assertions. 
+ """ + kernel_inputs = self.adjust_kernel_inputs(kernel_inputs, op_name) + input_nodes = kernel_inputs.nodes() + # Initial assertion from mm_common.scaled_mm_options + assert len(input_nodes) >= 4, ( + f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}" + ) + + # Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3]) + scale_a = input_nodes[2] + scale_b = input_nodes[3] + + # Scale compatibility assertion from mm_common.scaled_mm_options + def are_compatible_scales(size_a: Any, size_b: Any) -> bool: + # Same sized scales are compatible + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." 
+ ) + + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + + if not self._valid(kernel_inputs): + return + + # Get base template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, op_name + ): + # Add scaled MM-specific options (moved from mm_common.scaled_mm_options) + # Override accumulator type for scaled MM + template_kwargs["ACC_TYPE"] = "tl.float32" + + yield template_kwargs + + +class ScaledMMConfigMixin(BaseScaledMMConfigMixin): + """Mixing for scaled mm with the regular mm template""" + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + from ..kernel.mm_common import scale_mm_epilogue + + return { + **kwargs, + "suffix_args": kernel_inputs.count - 2, + "epilogue_fn": scale_mm_epilogue(), + "epilogue_fn_hash": "scale_mm_epilogue", + } + + def _valid(self, kernel_inputs: KernelInputs) -> bool: + assert isinstance(kernel_inputs, MMKernelInputs), ( + "Expect MMKernelInputs for ScaledMMConfigMixin" + ) + _, _, k = kernel_inputs.mnk_symbolic() + if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): + # Triton crashes however uncommon for real workloads + return False + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): + return False + return True + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter out bad configs for specific hardware. + On AMD MI350X (GFX 9.5+), skip configs with BLOCK_K<=64 due to lack of corresponding MFMA instructions. 
+ """ + + def should_skip_mi350x_config(config: BaseConfig) -> bool: + """Skip config if BLOCK_K<=64 on MI350X (GFX 9.5+)""" + try: + return ( + config.block_k <= 64 + and torch.version.hip is not None + and torch.cuda.get_device_capability() >= (9, 5) + ) + except RuntimeError: + # If no HIP GPUs are available, we can't check device capability + # so we don't skip any configs + return False + + filtered_configs = [c for c in configs if not should_skip_mi350x_config(c)] + return super()._filter_configs(filtered_configs) + + +# Scaled TMA-specific mixin for scaled MM templates with TMA +class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): + """ + Scaled TMA-specific mixin that extends BaseScaledMMConfigMixin with TMA functionality. + This is for scaled MM templates that use device TMA. + This inherits from BaseScaledMMConfigMixin and adds TMA-specific options. + """ + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + TMA specific filtering: + - num_warps=2 not safe for TMA + - block_k >= 32 required for TMA (requires inner-most dimension >= 32) + """ + configs = [c for c in configs if c.num_warps != 2 and c.block_k >= 32] + return super()._filter_configs(configs) + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate scaled TMA template configs with both scaled MM and TMA-specific options. 
+ """ + # Get base scaled MM template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + # Add TMA-specific options for device TMA scaled MM + template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + template_kwargs["NUM_SMS"] = get_num_sms() + template_kwargs["TMA_EXPERIMENTAL_API"] = not has_triton_stable_tma_api() + + yield template_kwargs + + +# Scaled Blackwell TMA-specific mixin for scaled MM templates with TMA +class ScaledBlackwellTMAConfigMixin( + BlackwellTMATemplateConfigMixin, ScaledMMConfigMixin +): + """ + Scaled Blackwell TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality. + This is for scaled MM templates that use device TMA on Blackwell. + This inherits from ScaledMMConfigMixin, which inherits the scale_mm_epilogue, and adds TMA-specific options. + """ + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Warp specialization-specific filtering (BlackwellTMATemplateConfigMixin) + (compilation issues occur in some versions of Triton) + - num_warps < 4 unsafe for warpspec + - num_stages < 2 unsafe for warpspec + + TMA-specific filtering: + - block_k >= 32 required for TMA (requires inner-most dimension >= 32) + """ + configs = [c for c in configs if c.block_k >= 32] + return super()._filter_configs(configs) + + +# Template-specific heuristic classes using multiple inheritance + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +@register_template_heuristic( + bmm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + """Standard MM template heuristic for CUDA""" + + +@register_template_heuristic( + mm_template.uid, "cuda", register=torch.version.hip is None, op_name="addmm" +) +@register_template_heuristic( + bmm_template.uid, "cuda", 
register=torch.version.hip is None, op_name="baddbmm" +) +class CUDAAddMMTemplateConfigHeuristic(AddMMConfigMixin, CUDAMMTemplateConfigHeuristic): + """Addmm specific mixin for CUDA""" + + +# TODO(coconutruben): deprecate once autoheuristic is deprecated +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="mm-ah", +) +class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + """Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.extra_mm_configs + self.exhaustive_configs = self.extra_mm_configs + + +@register_template_heuristic( + persistent_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAPersistentTMATemplateConfigHeuristic( + TMATemplateConfigMixin, CUDAConfigHeuristic +): + """Persistent TMA template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use persistent_mm_configs + self.mm_configs = self.persistent_mm_configs + + +@register_template_heuristic( + blackwell_ws_persistent_device_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDABlackwellPersistentTMATemplateConfigHeuristic( + BlackwellTMATemplateConfigMixin, CUDAConfigHeuristic +): + """Blackwell Persistent TMA template""" + + def __init__(self) -> None: + super().__init__() + self.mm_configs = self.blackwell_persistent_mm_configs + + +@register_template_heuristic( + persistent_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="addmm", +) +class CUDAAddmmPersistentTMATemplateConfigHeuristic( + AddMMConfigMixin, CUDAPersistentTMATemplateConfigHeuristic +): + """Addmm specific mixin for CUDA""" + + +@register_template_heuristic( + blackwell_ws_persistent_device_tma_mm_template.uid, + "cuda", + 
register=torch.version.hip is None, + op_name="addmm", +) +class CUDABlackwellAddmmPersistentTMATemplateConfigHeuristic( + AddMMConfigMixin, CUDABlackwellPersistentTMATemplateConfigHeuristic +): + """Addmm extension for DataCenter Blackwell Templates""" + + def __init__(self) -> None: + super().__init__() + # NOTE: to ensure that we pass tests, addmm needs a small config + self.mm_configs = ( + self.blackwell_persistent_mm_configs + + self.blackwell_persistent_addmm_configs + ) + + +@register_template_heuristic( + mm_template.uid, "cuda", register=torch.version.hip is None, op_name="scaled_mm" +) +class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic): + """Scaled MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + configs = [c for c in configs if c.block_k >= 32] + return super()._filter_configs(configs) + + +@register_template_heuristic( + scaled_mm_device_tma_epilogue_scaling_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="scaled_mm", +) +class CUDAScaledTMAEpilogueScalingTemplateConfigHeuristic( + ScaledTMAConfigMixin, CUDAConfigHeuristic +): + """Scaled TMA template heuristic for CUDA: epilogue scaling variants (TensorWise, RowWise)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + self.mm_configs = self.scaled_persistent_mm_configs + + +@register_template_heuristic( + scaled_mm_device_tma_main_loop_scaling_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="scaled_mm", +) +class CUDAScaledTMAMainLoopScalingTemplateConfigHeuristic( + ScaledTMAConfigMixin, CUDAConfigHeuristic +): + """ + Scaled TMA template heuristic for CUDA: + main loop scaling variants (BlockWise1x128, 
BlockWise1x32, BlockWise1x16, BlockWise128x128) + """ + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + self.mm_configs = self.scaled_persistent_mm_configs + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate main loop scaling kernel inputs. + """ + mat_a, mat_b, scale_a, scale_b = kernel_inputs._input_nodes + scale_a_size, scale_b_size = scale_a.get_size(), scale_b.get_size() + + scale_option_a, scale_option_b = get_scaling_options( + mat_a, mat_b, scale_a_size, scale_b_size + ) + tile_size_a = get_tile_size(scale_option_a) + tile_size_b = get_tile_size(scale_option_b) + + # Get base scaled MM template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + # Add scaling-specific options for main loop scaling variants + + # Inductor templates require compile-time constants passed in as tl.constexpr values. + # In cases in which the block size (BLOCK_*) is smaller than the tile size (128, 32, 16), + # scales must be broadcasted to BLOCK_* (rather than to a tile_sizextile_size chunk). 
+ + template_kwargs["TILE_SIZE_A"] = tile_size_a + template_kwargs["TILE_SIZE_B"] = tile_size_b + + template_kwargs["MIN_BLOCK_TILE_AM"] = min( + template_kwargs["BLOCK_M"], tile_size_a + ) + template_kwargs["MIN_BLOCK_TILE_AK"] = min( + template_kwargs["BLOCK_K"], tile_size_a + ) + template_kwargs["MIN_BLOCK_TILE_BK"] = min( + template_kwargs["BLOCK_K"], tile_size_b + ) + template_kwargs["MIN_BLOCK_TILE_BN"] = min( + template_kwargs["BLOCK_N"], tile_size_b + ) + + yield template_kwargs + + +@register_template_heuristic( + blackwell_ws_persistent_device_tma_mm_template.uid, # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin + "cuda", + register=torch.version.hip is None, + op_name="scaled_mm", +) +class CUDAScaledBlackwellTMATemplateConfigHeuristic( + ScaledBlackwellTMAConfigMixin, CUDAConfigHeuristic +): + """Scaled Blackwell TMA template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + # TODO: Tune scaled_persistent_mm_configs for Blackwell + self.mm_configs = self.scaled_persistent_mm_configs + + +@register_template_heuristic( + mm_plus_mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, CUDAConfigHeuristic +): + """MM Plus MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="int_mm", +) +class 
CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic): + """Int8 MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +# ROCm template-specific classes + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +@register_template_heuristic( + bmm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): + """Standard MM template heuristic for ROCm""" + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + mm_template.uid, "cuda", register=torch.version.hip is not None, op_name="addmm" +) +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + bmm_template.uid, "cuda", register=torch.version.hip is not None, op_name="baddbmm" +) +class ROCmAddMMTemplateConfigHeuristic(AddMMConfigMixin, ROCmMMTemplateConfigHeuristic): + """Addmm specific mixin for ROCm""" + + +# TODO(coconutruben): deprecate once autoheuristic is deprecated +@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None) +class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): + """Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.extra_mm_configs + self.exhaustive_configs = self.extra_mm_configs + + 
+@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="scaled_mm", +) +class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic): + """Scaled MM template heuristic for ROCm (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="int_mm", +) +class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic): + """Int8 MM template heuristic for ROCm""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic( + mm_plus_mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class ROCmMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, ROCmConfigHeuristic +): + """MM Plus MM template heuristic for ROCm""" + + def __init__(self) -> None: + super().__init__() + # self.default_num_stages is used to make sure all configs have that in ROCm land + # for mm_plus_mm, we actually just want stages = 1, as pipelining brings no benefits + self.default_num_stages = 1 + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: 
overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# CPU template-specific classes + + +@register_template_heuristic(mm_template.uid, "cpu") +@register_template_heuristic(bmm_template.uid, "cpu") +class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic): + """Standard MM template heuristic for CPU""" + + +@register_template_heuristic(mm_template.uid, "cpu", op_name="addmm") +@register_template_heuristic(bmm_template.uid, "cpu", op_name="baddbmm") +class CPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, CPUMMTemplateConfigHeuristic): + """Addmm specific mixin for CPU""" + + +@register_template_heuristic(mm_template.uid, "cpu", op_name="scaled_mm") +class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic): + """Scaled MM template heuristic for CPU (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic(mm_template.uid, "cpu", op_name="int_mm") +class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic): + """Int8 MM template heuristic for CPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated 
exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic(mm_plus_mm_template.uid, "cpu") +class CPUMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, CPUConfigHeuristic +): + """MM Plus MM template heuristic for CPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# XPU template-specific classes + + +@register_template_heuristic(mm_template.uid, "xpu") +@register_template_heuristic(bmm_template.uid, "xpu") +class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic): + """Standard MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + + # TODO(etaf): Design proper exhaustive search space for XPU. 
+ self.exhaustive_configs = self.mm_configs + + +@register_template_heuristic(mm_template.uid, "xpu", op_name="addmm") +@register_template_heuristic(bmm_template.uid, "xpu", op_name="baddbmm") +class XPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, XPUMMTemplateConfigHeuristic): + """Addmm specific mixin for XPU""" + + +@register_template_heuristic( + persistent_tma_mm_template.uid, + "xpu", +) +class XPUPersistentTMATemplateConfigHeuristic( + TMATemplateConfigMixin, XPUConfigHeuristic +): + """Persistent TMA template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use persistent_mm_configs + self.mm_configs = self.persistent_mm_configs + + +@register_template_heuristic(persistent_tma_mm_template.uid, "xpu", op_name="addmm") +class XPUAddmmPersistentTMATemplateConfigHeuristic( + AddMMConfigMixin, XPUPersistentTMATemplateConfigHeuristic +): + """Addmm specific mixin for XPU""" + + +@register_template_heuristic(mm_template.uid, "xpu", op_name="scaled_mm") +class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic): + """Scaled MM template heuristic for XPU (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic(mm_template.uid, "xpu", op_name="int_mm") +class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic): + """Int8 MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as 
mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic(mm_plus_mm_template.uid, "xpu") +class XPUMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, XPUConfigHeuristic +): + """MM Plus MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# MTIA template-specific classes + + +@register_template_heuristic(mm_template.uid, "mtia") +@register_template_heuristic(bmm_template.uid, "mtia") +class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic): + """Standard MM template heuristic for MTIA""" + + +@register_template_heuristic(mm_template.uid, "mtia", op_name="addmm") +@register_template_heuristic(bmm_template.uid, "mtia", op_name="baddbmm") +class MTIAAddMMTemplateConfigHeuristic(AddMMConfigMixin, MTIAMMTemplateConfigHeuristic): + """Addmm specific mixin for MTIA""" + + +@register_template_heuristic(mm_template.uid, "mtia", op_name="scaled_mm") +class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic): + """Scaled MM template heuristic for MTIA (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for 
scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic(mm_template.uid, "mtia", op_name="int_mm") +class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic): + """Int8 MM template heuristic for MTIA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic(mm_plus_mm_template.uid, "mtia") +class MTIAMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, MTIAConfigHeuristic +): + """MM Plus MM template heuristic for MTIA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton_addmm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton_addmm.py new file mode 100644 index 0000000000000000000000000000000000000000..a6643d1ce2a90de0f31ef07e6f20d689d9d101b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton_addmm.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from ..kernel.mm_common import addmm_epilogue +from .base import TemplateConfigHeuristics + + +if 
TYPE_CHECKING: + from ..kernel_inputs import KernelInputs + + +class AddMMConfigMixin(TemplateConfigHeuristics): + """ + Simple mixin to handle scalars for addmm like operators (addmm, baddbmm) + """ + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + assert op_name in [ + "addmm", + "baddbmm", + ], f"op_name={op_name} invalid for AddMMConfigMixin" + alpha = kernel_inputs.get_scalar("alpha") + beta = kernel_inputs.get_scalar("beta") + return { + **kwargs, + "epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta), + "epilogue_fn_hash": str( + ["addmm_epilogue", kernel_inputs.out_dtype(), alpha, beta] + ), + "prefix_args": 1, + } diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bf6a3d54c298f9b51db9f7b9e32a2e0b4a951c0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/closure.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/closure.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421598d96b44732cb7613801f055f8050a174f6a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/closure.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/computation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/computation.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4c557750832abc2dbcbaf7a19005222ae2269009 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/computation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd64564949a64e3021a1b2c16b13ec92ab4b1602 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/debug.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b38733be07f110ebe0c13358de668937a023e35 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/debug.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/device_context.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/device_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcfa2d002866c24c2acb299e0d4cf1393496afbb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/device_context.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..aec6bea4b5926499f5d0857040d5da9b897edb61 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97b6f45720d58c7f86dcd92247b87bfb06d1c3e8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/metrics.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6f096abb2c28f01cc1011a8ecf1925e53de6317 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/metrics.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfdd3a120fc313d0843d847daaf1d55eb354c39d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-312.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..0f04312869b9d04ca332712067d97b118bead57c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1c8441141ffa5c54ea33afa47022a2d0c597b6f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3202716174d8a54db35865f8d8498f91248e997a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b06faf9431327386d864d2e3069b1c79a1bdc32c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..bf58eb0ab8bd0df9c76dabb61946d6148d778088 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a99071667aed39dbd6aa09b27cf2910805fa072 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims_common/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims_common/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9fbd6341b6be54962efbba8d14f2347491bbfc2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims_common/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9e2f08fa1a9ca26b13e05551e5de0aac9d28d9c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7e9df5d1c2215a63270bd14b9050c71e543f06ac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22809cfd5dc25792d77070c269fc8d111a12eed0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py @@ -0,0 +1,15 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "23.2" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014 %s" % __author__ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33e04e73c00ada708f5639c4922118f5e763f816 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d5daac6178f98b9a96c8eb72a6f3f7c0030ce045 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9077ecd5baf8ae73d05dbc10d965b6aa8741489a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..90a6465f9682c886363eea5327dac64bf623a6ff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. 
+ + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/version.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/version.py new file mode 100644 index 0000000000000000000000000000000000000000..5faab9bd0dcf28847960162b2b4f13a8a556ef20 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/version.py @@ -0,0 +1,563 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. +""" +.. 
testsetup:: + + from packaging.version import parse, Version +""" + +import itertools +import re +from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"] + +LocalType = Tuple[Union[int, str], ...] + +CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]] +CmpLocalType = Union[ + NegativeInfinityType, + Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...], +] +CmpKey = Tuple[ + int, + Tuple[int, ...], + CmpPrePostDevType, + CmpPrePostDevType, + CmpPrePostDevType, + CmpLocalType, +] +VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool] + + +class _Version(NamedTuple): + epoch: int + release: Tuple[int, ...] + dev: Optional[Tuple[str, int]] + pre: Optional[Tuple[str, int]] + post: Optional[Tuple[str, int]] + local: Optional[LocalType] + + +def parse(version: str) -> "Version": + """Parse the given version string. + + >>> parse('1.0.dev1') + + + :param version: The version string to parse. + :raises InvalidVersion: When the version string is not a valid version. + """ + return Version(version) + + +class InvalidVersion(ValueError): + """Raised when a version string is not a valid version. + + >>> Version("invalid") + Traceback (most recent call last): + ... + packaging.version.InvalidVersion: Invalid version: 'invalid' + """ + + +class _BaseVersion: + _key: Tuple[Any, ...] + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. 
+ def __lt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +_VERSION_PATTERN = r""" + v? + (?: + (?:(?P<epoch>[0-9]+)!)? # epoch + (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment + (?P<pre>
                                          # pre-release
+            [-_\.]?
+            (?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
+            [-_\.]?
+            (?P<pre_n>[0-9]+)?
+        )?
+        (?P<post>                                         # post release
+            (?:-(?P<post_n1>[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?P<post_l>post|rev|r)
+                [-_\.]?
+                (?P<post_n2>[0-9]+)?
+            )
+        )?
+        (?P<dev>                                          # dev release
+            [-_\.]?
+            (?P<dev_l>dev)
+            [-_\.]?
+            (?P<dev_n>[0-9]+)?
+        )?
+    )
+    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+# Public alias of the private pattern string defined above; it is the name
+# that gets wrapped in ``^\s*`` / ``\s*$`` and compiled further down.
+VERSION_PATTERN = _VERSION_PATTERN
+"""
+A string containing the regular expression used to match a valid version.
+
+The pattern is not anchored at either end, and is intended for embedding in larger
+expressions (for example, matching a version number as part of a file name). The
+regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
+flags set.
+
+:meta hide-value:
+"""
+
+
+class Version(_BaseVersion):
+    """This class abstracts handling of a project's versions.
+
+    A :class:`Version` instance is comparison aware and can be compared and
+    sorted using the standard Python interfaces.
+
+    >>> v1 = Version("1.0a5")
+    >>> v2 = Version("1.0")
+    >>> v1
+    <Version('1.0a5')>
+    >>> v2
+    <Version('1.0')>
+    >>> v1 < v2
+    True
+    >>> v1 == v2
+    False
+    >>> v1 > v2
+    False
+    >>> v1 >= v2
+    False
+    >>> v1 <= v2
+    True
+    """
+
+    # Unlike VERSION_PATTERN itself, the compiled form is anchored (allowing
+    # only surrounding whitespace) so that the whole string must be a version.
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    _key: CmpKey
+
+    def __init__(self, version: str) -> None:
+        """Initialize a Version object.
+
+        :param version:
+            The string representation of a version which will be parsed and normalized
+            before use.
+        :raises InvalidVersion:
+            If the ``version`` does not conform to PEP 440 in any way then this
+            exception will be raised.
+        """
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        """A representation of the Version that shows all internal state.
+
+        >>> Version('1.0.0')
+        <Version('1.0.0')>
+        """
+        # Embed the normalized string form produced by __str__.
+        return f"<Version('{self}')>"
+
+    def __str__(self) -> str:
+        """A string representation of the version that can be round-tripped.
+
+        >>> str(Version("1.0a5"))
+        '1.0a5'
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        """The epoch of the version.
+
+        >>> Version("2.0.0").epoch
+        0
+        >>> Version("1!2.0.0").epoch
+        1
+        """
+        return self._version.epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        """The components of the "release" segment of the version.
+
+        >>> Version("1.2.3").release
+        (1, 2, 3)
+        >>> Version("2.0.0").release
+        (2, 0, 0)
+        >>> Version("1!2.0.0.post0").release
+        (2, 0, 0)
+
+        Includes trailing zeroes but not the epoch or any pre-release / development /
+        post-release suffixes.
+        """
+        return self._version.release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        """The pre-release segment of the version.
+
+        >>> print(Version("1.2.3").pre)
+        None
+        >>> Version("1.2.3a1").pre
+        ('a', 1)
+        >>> Version("1.2.3b1").pre
+        ('b', 1)
+        >>> Version("1.2.3rc1").pre
+        ('rc', 1)
+        """
+        return self._version.pre
+
+    @property
+    def post(self) -> Optional[int]:
+        """The post-release number of the version.
+
+        >>> print(Version("1.2.3").post)
+        None
+        >>> Version("1.2.3.post1").post
+        1
+        """
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        """The development number of the version.
+
+        >>> print(Version("1.2.3").dev)
+        None
+        >>> Version("1.2.3.dev1").dev
+        1
+        """
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        """The local version segment of the version.
+
+        >>> print(Version("1.2.3").local)
+        None
+        >>> Version("1.2.3+abc").local
+        'abc'
+        """
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        """The public portion of the version.
+
+        >>> Version("1.2.3").public
+        '1.2.3'
+        >>> Version("1.2.3+abc").public
+        '1.2.3'
+        >>> Version("1.2.3+abc.dev1").public
+        '1.2.3'
+        """
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        """The "base version" of the version.
+
+        >>> Version("1.2.3").base_version
+        '1.2.3'
+        >>> Version("1.2.3+abc").base_version
+        '1.2.3'
+        >>> Version("1!1.2.3+abc.dev1").base_version
+        '1!1.2.3'
+
+        The "base version" is the public version of the project without any pre or post
+        release markers.
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        """Whether this version is a pre-release.
+
+        >>> Version("1.2.3").is_prerelease
+        False
+        >>> Version("1.2.3a1").is_prerelease
+        True
+        >>> Version("1.2.3b1").is_prerelease
+        True
+        >>> Version("1.2.3rc1").is_prerelease
+        True
+        >>> Version("1.2.3dev1").is_prerelease
+        True
+        """
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        """Whether this version is a post-release.
+
+        >>> Version("1.2.3").is_postrelease
+        False
+        >>> Version("1.2.3.post1").is_postrelease
+        True
+        """
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        """Whether this version is a development release.
+
+        >>> Version("1.2.3").is_devrelease
+        False
+        >>> Version("1.2.3.dev1").is_devrelease
+        True
+        """
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        """The first item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").major
+        1
+        """
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        """The second item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").minor
+        2
+        >>> Version("1").minor
+        0
+        """
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        """The third item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").micro
+        3
+        >>> Version("1").micro
+        0
+        """
+        return self.release[2] if len(self.release) >= 3 else 0
+
+def _parse_letter_version(
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
+) -> Optional[Tuple[str, int]]:
+    """Normalize one ``(letter, number)`` pair captured from a version string.
+
+    Returns ``None`` when neither part is present. A number without a letter
+    is treated as the implicit post-release spelling (e.g. ``1.0-1``), and a
+    letter without a number gets an implicit ``0``.
+    """
+    if not letter:
+        # Only a bare number: implicit post release syntax.
+        if number:
+            return "post", int(number)
+        return None
+
+    # Lower-case first, then fold alternate spellings onto the preferred one.
+    spelling = letter.lower()
+    if spelling == "alpha":
+        spelling = "a"
+    elif spelling == "beta":
+        spelling = "b"
+    elif spelling in ("c", "pre", "preview"):
+        spelling = "rc"
+    elif spelling in ("rev", "r"):
+        spelling = "post"
+
+    return spelling, int(0 if number is None else number)
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
+    """
+    Split a local version label on ``.``, ``_`` or ``-``.
+
+    All-digit parts become ``int`` and every other part is lower-cased, so
+    ``abc.1.twelve`` yields ``("abc", 1, "twelve")``. ``None`` is passed through.
+    """
+    if local is None:
+        return None
+
+    pieces = []
+    for chunk in _local_version_separators.split(local):
+        pieces.append(int(chunk) if chunk.isdigit() else chunk.lower())
+    return tuple(pieces)
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[LocalType],
+) -> CmpKey:
+    """Build the tuple used to order and hash versions.
+
+    Missing segments are replaced by ``Infinity`` / ``NegativeInfinity``
+    sentinels so that plain tuple comparison gives the intended ordering.
+    """
+    # Releases compare with trailing zeros ignored, so cut the tuple just
+    # after its last non-zero component.
+    keep = len(release)
+    while keep and release[keep - 1] == 0:
+        keep -= 1
+    _release = tuple(release[:keep])
+
+    # A dev release with neither pre nor post segment (1.0.dev0) has to sort
+    # before any pre-release (1.0a0), so its pre slot gets the lowest value.
+    # Any other version without a pre-release sorts after those that have one.
+    _pre: CmpPrePostDevType
+    if pre is not None:
+        _pre = pre
+    elif post is None and dev is not None:
+        _pre = NegativeInfinity
+    else:
+        _pre = Infinity
+
+    # No post segment sorts before having one.
+    _post: CmpPrePostDevType = NegativeInfinity if post is None else post
+
+    # No development segment sorts after having one.
+    _dev: CmpPrePostDevType = Infinity if dev is None else dev
+
+    _local: CmpLocalType
+    if local is None:
+        # No local segment sorts before having one.
+        _local = NegativeInfinity
+    else:
+        # PEP 440 ordering for local segments:
+        # - alphanumeric parts sort before numeric parts
+        # - alphanumeric parts sort lexicographically
+        # - numeric parts sort numerically
+        # - a shorter label sorts first when it is a prefix of a longer one
+        _local = tuple(
+            (part, "") if isinstance(part, int) else (NegativeInfinity, part)
+            for part in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/forward_ad.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/forward_ad.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f88a794922fae978b2c6bd201800ce7528c0695
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/forward_ad.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/function.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/function.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0883e3908eebebf774320543e285c1e5829d91fc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/function.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/gradcheck.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/gradcheck.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eac69cb6353a35040e2ed2954c7cd9e2cdf54716
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/gradcheck.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e88283ae08a94367fbfa848a31da77233b850185
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/graph.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_util.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..014731fa692bd433e65af5e1fde0a495fa7247ad
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_util.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4170fad3eeac788dcb36b6ae1ddbee1b44dc25a1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__init__.py
@@ -0,0 +1 @@
+from .tensor import *  # noqa: F403
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..716ae1db726ad5b397426e0669cfd241ee7ee556
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/tensor.py
@@ -0,0 +1,72 @@
+# mypy: allow-untyped-defs
+import operator
+from functools import reduce
+from typing_extensions import deprecated
+
+import torch
+import torch._utils
+from torch.autograd.function import Function
+
+
+class Type(Function):
+    """Autograd function that casts a tensor with ``Tensor.type``.
+
+    ``forward`` records the Python type of the input and its device index
+    (``-1`` when the input is not a CUDA tensor); ``backward`` casts the
+    gradient back using the recorded type.
+    """
+
+    @staticmethod
+    @deprecated(
+        "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, "
+        "please use `torch.tensor.to(dtype=dtype)` instead.",
+        category=FutureWarning,
+    )
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, i, dest_type):
+        ctx.input_type = type(i)
+        ctx.input_device = i.get_device() if i.is_cuda else -1
+        return i.type(dest_type)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        device = ctx.input_device
+        if device != -1:
+            # Do the cast with the recorded device index selected.
+            with torch.accelerator.device_index(device):
+                return grad_output.type(ctx.input_type), None
+        # The second element is the gradient slot for ``dest_type``.
+        return grad_output.type(ctx.input_type), None
+
+
+# TODO: deprecate this
+class Resize(Function):
+    """Autograd function that views a tensor with a new shape.
+
+    The requested shape must hold exactly as many elements as the input;
+    ``backward`` views the incoming gradient back to the input's size.
+    """
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, tensor, sizes):
+        """Return ``tensor`` viewed as ``sizes``.
+
+        Stores ``sizes``, their product (``ctx.numel``) and the input size on
+        ``ctx`` for use in ``backward``.
+
+        Raises:
+            RuntimeError: if the product of ``sizes`` differs from
+                ``tensor.numel()``.
+        """
+        ctx.sizes = sizes
+        # Element count implied by the requested shape (1 for an empty ``sizes``).
+        ctx.numel = reduce(operator.mul, sizes, 1)
+        if tensor.numel() != ctx.numel:
+            raise RuntimeError(
+                (
+                    "requested resize to {} ({} elements in total), "
+                    "but the given tensor has a size of {} ({} elements). "
+                    "autograd's resize can only change the shape of a given "
+                    "tensor, while preserving the number of elements. "
+                ).format(
+                    "x".join(map(str, sizes)),
+                    ctx.numel,
+                    "x".join(map(str, tensor.size())),
+                    tensor.numel(),
+                )
+            )
+        ctx.input_sizes = tensor.size()
+        if tensor.is_quantized:
+            # NOTE(review): this copies the tensor onto itself; the intent is
+            # not evident from this block — confirm before changing.
+            tensor.copy_(tensor)
+            return tensor.contiguous().view(*sizes)
+        if tensor.is_contiguous():
+            # NOTE(review): presumably ``tensor.new(tensor)`` is used so the
+            # view is taken from a fresh tensor object rather than from the
+            # input itself — confirm.
+            result = tensor.new(tensor).contiguous().view(*sizes)
+            return result
+        else:
+            return tensor.contiguous().view(*sizes)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        """Return the gradient viewed as the input size, and ``None`` for ``sizes``.
+
+        Raises:
+            AssertionError: if ``grad_output`` does not have ``ctx.numel``
+                elements.
+        """
+        if grad_output.numel() != ctx.numel:
+            raise AssertionError(
+                f"Expected grad_output to have {ctx.numel} elements, but got {grad_output.numel()}"
+            )
+        return grad_output.contiguous().view(ctx.input_sizes), None
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e74e21d3cef22c0fd459eff5934d4e531d5456d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/utils.py
@@ -0,0 +1,26 @@
+# mypy: allow-untyped-defs
+
+
+def maybe_view(tensor, size, check_same_size=True):
+    """Return ``tensor`` viewed as ``size``.
+
+    When ``check_same_size`` is true and the tensor already has that size,
+    the tensor itself is returned unchanged.
+    """
+    already_matches = check_same_size and tensor.size() == size
+    return tensor if already_matches else tensor.contiguous().view(size)
+
+
+def maybe_unexpand(tensor, old_size, check_same_size=True):
+    """Sum ``tensor`` down towards ``old_size``.
+
+    Extra leading dimensions are summed away; each remaining dimension whose
+    size differs from ``old_size`` is summed with ``keepdim=True``. When
+    ``check_same_size`` is true and the sizes already match, the tensor is
+    returned unchanged.
+    """
+    if check_same_size and tensor.size() == old_size:
+        return tensor
+
+    extra_leading = tensor.dim() - len(old_size)
+    trailing_sizes = tensor.size()[extra_leading:]
+
+    # Positions, relative to the trailing dims, whose size differs.
+    mismatched = []
+    for axis, (current, target) in enumerate(zip(trailing_sizes, old_size)):
+        if current != target:
+            mismatched.append(axis)
+
+    reduced = tensor
+    for _ in range(extra_leading):
+        reduced = reduced.sum(0, keepdim=False)
+    for axis in mismatched:
+        reduced = reduced.sum(axis, keepdim=True)
+    return reduced
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..051e1213e7cbeca224decd20c7f7ac6393ccde60
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dbff77bf070595d99bcc4f2def87657cd5856f6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2deb892277632d9439148f6e5849de45a61be1ea
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d48862bef876e16418f248c98eb8db1dc6264ebd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bc5a52c7362d52b7246d093543ac76ddb13814c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e15eecb0617b4c2b01bec9bfa7d08dce15bc65aa
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc709a0aeb6173cf7f3e9e9b4efc8836710520ec
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ffb9ecde05564d36789eff58d84c80b3cdde8e2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63e7ccf9875611077b8b8f6077324586ac3635fa
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/miopen/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/miopen/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5af59cae8f783bdaa767d591ce7a8cbf3341ead5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/miopen/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df2aa1207c58f39959c46aabf41bf153d5884cf6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..becddba492631bdcea2a54131ddcac62a50cad1c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4861d782cfc9c165b29b9b3c3e1e952cd58d1e7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bd61555c08d00ba0280a8de5e4cef448c53d7b0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f9d558a6b90246a75c0aef47f4538d2282f00a7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..771e3e9e70e7cb0af409f6560299996f342f0e0b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54ccc369e3d76b6485fb2ddb72f52a1682836ec3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50e404a2e675e1f0d90b8b6e197fff497744f9de
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9613983795cb5bd51b5ede03fabc63d002d9f8e2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..519d3359cb40c9ba58c330aab4e82d2808b76386
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/csrc/inductor/aoti_runtime/model.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/csrc/inductor/aoti_runtime/model.h
new file mode 100644
index 0000000000000000000000000000000000000000..253c5e917e76bdc8a2adc669404fc8d5c40b6b27
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/csrc/inductor/aoti_runtime/model.h
@@ -0,0 +1,62 @@
+#pragma once
+
+// WARNING: Be careful when adding new includes here. This header will be used
+// in model.so, and should not refer to any aten/c10 headers except the stable
+// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
+// applies to other files under torch/csrc/inductor/aoti_runtime/.
+#include <torch/csrc/inductor/aoti_runtime/model_base.h>
+
+namespace torch::aot_inductor {
+
+class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
+ public:
+  AOTInductorModel(
+      std::shared_ptr<ConstantMap> constants_map,
+      std::shared_ptr<std::vector<ConstantHandle>> constants_array,
+      const std::string& device_str,
+      std::optional<std::string> cubin_dir);
+
+  std::unordered_map<std::string, AtenTensorHandle> const_run_impl(
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor,
+      bool initialization = false);
+  // Presumably fills output_handles in place (non-const ref) -- TODO confirm.
+  void _const_run_impl(
+      std::vector<AtenTensorHandle>& output_handles,
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor);
+
+  void run_impl(
+      AtenTensorHandle*
+          input_handles, // array of input AtenTensorHandle; handles
+                         // are stolen; the array itself is borrowed
+      AtenTensorHandle*
+          output_handles, // array for writing output AtenTensorHandle; handles
+                          // will be stolen by the caller; the array itself is
+                          // borrowed
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor);
+  // Templated on the Inputs/Outputs types; returns Outputs by value.
+  template <typename Inputs, typename Outputs>
+  Outputs run_impl_minimal_arrayref_interface(
+      const Inputs& inputs,
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor);
+  // Factory: forwards its arguments to the constructor via std::make_unique.
+  static std::unique_ptr<AOTInductorModel> Create(
+      std::shared_ptr<ConstantMap> constants_map,
+      std::shared_ptr<std::vector<ConstantHandle>> constants_array,
+      const std::string& device_str,
+      std::optional<std::string> cubin_dir) {
+    return std::make_unique<AOTInductorModel>(
+        std::move(constants_map),
+        std::move(constants_array),
+        device_str,
+        std::move(cubin_dir));
+  }
+  // Only Create() is defined inline; the other members are declared here.
+ private:
+  std::unique_ptr<AOTInductorModelKernelsBase> kernels_;
+};
+
+} // namespace torch::aot_inductor
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d9686b3f95d75928160c55f4881fc221a45dc0b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61214404ba6039b5bedc5a6c1cbdc45c404b4735
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..578c8b9be90116b7f164f5fefc802eeaa49c0725
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e321d37503617b60ae1355b1f749125fd38440d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..217ac1e8cda0993c1d6c0bc612bb0452bfe5b969
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86cd9b2bba08b37b1908da97a24af7bce8349195
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbdec4fa88d3ba66b12c60011bc058cfbfa4aeae
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ab16e8a1cfa8010884ed5889bcaf3f8babf4a2b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14e198812448363fc8438aba15862e013240e6bf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..473262f8173460d48bcb76a7f07ba1cdb3232d80
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0757804651cf553e4d594e03ca3fc4b8d6d910a7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c987ebe5a8c5d369924ce77bd6549427f5445
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f85ef4217eba910a804f7241e93f6fcab6be5e0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..164a204f545e32dba2e1ea974be4b48a9abdb890
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bb89e78ed474ef336d3df66006d347e6939c4ec
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcb99aac4a1ad6e04288a7eaddbdf155a75c3b07
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..646e51a30133a7617c71a08815a0999af40cdb9b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..931ef2bf90dcf5fab288a7eb8758818111012fbe
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..769f311da3c228c28d1d39033166a34924a80ffd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..430adf9ed7a432024db8d422e0eb1dd2c729b117
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43d201629b7e593b7d34d87c88cd839716b0da51
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4dfd6a0275a43cf3b30acf0c3c65bd13d5c37d7f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97eb747369227cd501aeb0c8c73da3d2f7f1c45b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f114ec4e6ec38e8ce6d87edfbb44de43efe845bc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b110d7a622fe78856e532002a23fa45b73236e6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca432e30f0ff64bc862e6e5badbbcb0fbeeac92d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d06b993fc35642de43e29bf2e9f83895b6d6f726
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9549c6901b94fcbfaace998440e68a951a4ca4dc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89577abf4be1f4b05645c16cb880626501f5bd71
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8ff6341dd821739cade493b1c0ce528b1b711f7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ea7cdc81a9c1befa14a5865b0de65fdf79f4a40
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e3b64a9301ab894c259671fb27109e182d96d8d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..658d8ff899c2d1551469d04cb92febf68a5a8c67
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f174b3d6d78a39b10119b1ab4a6c3a6a521e08a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..127b49f548c894a2fc6c073de9baa0ce178d8976
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43e2ea0a68cde1c4d19c4db09a7fac8ddd04028b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..129000b78bb0d87394629ed55b48a0196f95825f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e7278d3a97ee0c867ab08545d88e6df87f08979
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e89ce9b26f679539fc4aadc92db0b733d21b1f43
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4adca97e5d42e29001c8811f67921ad5267ca8eb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6edb3727bf777f79277e868766e3e55d1fcff99
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf009ee06fb82cb60326d872ce29a1ae43620589
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55e799427597c2f1a6e14811d0f2389a080e0bc9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c301ed8294550726d9c84de4d3535bfa95f9b863
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54731ac9905d8e577c01c97780548c4b1079ace9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86a7a33a3baa434a62ab37d6265a89303d6730c2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c758e7598dd3826404f20456f94240ee658c7f39
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cfee8088e6cd0e60d03c904ec65651d4163dc8f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4140d7f9321b35e0a65174e206a0bb68dc46aba
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d5a7af4022fb8253247f6658ad0f8b503408cb9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f30779ecc28df106658ca80fb103ee3735e5a1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py
@@ -0,0 +1,119 @@
+import re
+from collections.abc import Callable
+from typing import Any, Union
+
+import torch
+from torch.utils._pytree import tree_flatten_with_path, tree_map
+
+
+KeyPath = tuple[Any, ...]  # pytree key path, as yielded by tree_flatten_with_path
+NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]]
+
+__all__ = [  # names exported by `from ... import *`
+    "normalize_source_name",
+    "module_to_nested_dict",
+    "track_dynamism_across_examples",
+    "clone_and_convert_to_meta",
+]
+
+
+def normalize_source_name(name: str) -> str:
+    """Rewrite every attribute access such as ``.x`` in ``name`` as ``['x']``."""
+    return re.sub(pattern=r"\.([a-zA-Z_][a-zA-Z0-9_]*)", repl=r"['\1']", string=name)
+
+
+def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
+    """Recursively convert an nn.Module into a nested dictionary.
+
+    Public, non-callable int/float/Tensor attributes (bools excluded) are stored
+    under their own names; direct parameters and buffers are stored together
+    under "_parameters"; child modules are converted recursively under "_modules".
+    """
+    self_dict: dict[str, Any] = {"_parameters": {}, "_modules": {}}
+
+    for attr_name in dir(module):
+        if attr_name.startswith("_"):
+            continue
+        try:
+            # Read the attribute once so a property is evaluated a single time.
+            attr_value = getattr(module, attr_name)
+        except NotImplementedError:
+            # Skip attributes that raise NotImplementedError since they won't
+            # contain any dynamism anyways.
+            continue
+        if callable(attr_value) or type(attr_value) is bool:
+            continue  # bools are ints by subclassing; exclude them explicitly
+        if isinstance(attr_value, (int, float, torch.Tensor)):
+            self_dict[attr_name] = attr_value
+
+    for name, param in module.named_parameters(recurse=False):
+        self_dict["_parameters"][name] = param
+    for name, buffer in module.named_buffers(recurse=False):
+        self_dict["_parameters"][name] = buffer
+
+    for name, submodule in module.named_children():
+        self_dict["_modules"][name] = module_to_nested_dict(submodule)
+
+    return self_dict
+
+
+def track_dynamism_across_examples(
+    example_inputs: list[Any],
+) -> dict[Any, Any]:
+    """
+    Report, per input leaf, which dimensions vary across ``example_inputs``.
+
+    Every example is flattened with pytree; int, float and Tensor leaves are
+    tracked by key path (a Tensor contributes its shape, a scalar contributes a
+    one-element tuple holding its value), and a dimension counts as dynamic when
+    more than one distinct value was observed for it.
+
+    Returns a dict keyed by each path's top-level ``.key``, whose values map
+    ``"L" + <path>`` strings to a tuple of per-dimension booleans.
+
+    NOTE: an nn.Module stored under ``ex["self"]`` is replaced, in place, by
+    its ``module_to_nested_dict`` form, so the caller's examples are modified.
+    """
+    observed: dict[KeyPath, tuple[list[set[Any]], bool]] = {}
+
+    for ex in example_inputs:
+        if "self" in ex and isinstance(ex["self"], torch.nn.Module):
+            ex["self"] = module_to_nested_dict(ex["self"])
+        for key_path, leaf in tree_flatten_with_path(ex)[0]:
+            is_tensor = isinstance(leaf, torch.Tensor)
+            if is_tensor:
+                dims: tuple[int | float, ...] = tuple(leaf.shape)
+            elif isinstance(leaf, (int, float)):
+                dims = (leaf,)
+            else:
+                continue
+            # The first sighting of a path fixes its stored is-tensor flag.
+            per_dim, _ = observed.setdefault(key_path, ([], is_tensor))
+            # Grow the per-dimension sets when this example has a higher rank.
+            per_dim.extend(set() for _ in range(len(dims) - len(per_dim)))
+            for seen, dim in zip(per_dim, dims):
+                seen.add(dim)
+
+    output: dict[Any, Any] = {}
+    for key_path, (per_dim, _is_tensor) in observed.items():
+        final_dyn = tuple(len(seen) > 1 for seen in per_dim)
+        path_str = "L" + "".join(map(str, key_path))
+        # Group results by the key of the first path component.
+        top_key = key_path[0].key  # type: ignore[attr-defined]
+        output.setdefault(top_key, {})[path_str] = final_dyn
+    return output
+
+
+
+def clone_and_convert_to_meta(example_input: Any) -> Any:
+    """
+    Return a copy of the ``example_input`` pytree in which every tensor leaf is
+    cloned and moved to device="meta"; non-tensor leaves are passed through by
+    reference. Nested structures are handled recursively via pytree.
+    """
+    return tree_map(
+        lambda leaf: (
+            leaf.clone().to(device="meta") if isinstance(leaf, torch.Tensor) else leaf
+        ),
+        example_input,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/debug.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..b87dee9db9c73f0b4ea1a0a27682a167e125a71d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/debug.py
@@ -0,0 +1,33 @@
+from collections.abc import Sequence
+
+import torch.fx as fx
+
+
+__all__ = ["set_trace"]
+
+
+def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
+    """
+    Prepend a `pdb.set_trace()` call to `gm`'s generated python code, so that
+    running `gm` drops into pdb. `gm` is recompiled for this to take effect.
+
+    Args:
+        gm: graph module whose generated code receives the breakpoint.
+
+    Returns:
+        the same `gm`, recompiled with the breakpoint inserted.
+    """
+    breakpoint_line = "import pdb; pdb.set_trace()\n"
+
+    def add_breakpoint(body: Sequence[str]) -> list[str]:
+        return [breakpoint_line, *body]
+
+    with gm.graph.on_generate_code(
+        make_transformer=lambda prev: (
+            # chain after any transformer that is already registered
+            (lambda body: add_breakpoint(prev(body))) if prev else add_breakpoint
+        )
+    ):
+        gm.recompile()
+
+    return gm
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd239d78842dd8ba3cbfbf2d03e259a19427489b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py
@@ -0,0 +1,178 @@
+# mypy: allow-untyped-defs
+import itertools
+import operator
+
+import torch
+from torch.fx._symbolic_trace import symbolic_trace
+from torch.fx.node import Node
+from torch.fx.passes.tools_common import legalize_graph
+
+
+def split_result_tensors(
+    result: torch.Tensor, inputs: list[torch.Tensor]
+) -> tuple[torch.Tensor, ...]:
+    """
+    Undo the concatenation performed by the merge_matmul graph transformation
+    below: cut the output of a merged matmul back into one chunk per original
+    input tensor.
+
+    Arguments:
+        result: The merged matmul result tensor.
+        inputs: The list of inputs that were merged into one for the matmul.
+
+    Returns:
+        Tuple with one chunk of `result` per entry of `inputs`; chunk sizes are
+        the first-dimension sizes of the inputs.
+    """
+    # While the fx tracer is running, `result` is a Proxy and x.shape[0] is
+    # not a concrete int, so placeholder split sizes of 0 are used instead;
+    # otherwise the real first-dimension sizes are used.
+    tracing = isinstance(result, torch.fx.Proxy)
+    splits = [0 if tracing else x.shape[0] for x in inputs]
+
+    # pyrefly: ignore [bad-argument-type]
+    return torch.split(result, splits)
+
+
+def may_depend_on(a: Node, b: Node, search_depth: int = 6):
+    """
+    Conservatively decide whether node `a` has a data dependency on node `b`
+    within a torch.fx.Graph.
+
+    Arguments:
+        a: The node that may have a dependency on b.
+        b: The node that a may have a dependency on.
+        search_depth: Maximum number of input edges that are followed
+                        backwards from `a` while looking for `b`. When this
+                        budget runs out before a conclusion is reached, a
+                        dependency is assumed.
+
+    Returns:
+        True if a may depend on b, False if it definitely does not.
+    """
+    # A node is considered to depend on itself.
+    if a == b:
+        return True
+
+    producers = a.all_input_nodes
+
+    # Without inputs there is nothing through which `a` could reach `b`.
+    if not producers:
+        return False
+
+    # Depth budget exhausted without reaching a conclusion: conservatively
+    # report that a dependency may exist.
+    if search_depth == 0:
+        return True
+
+    # Otherwise `a` may depend on `b` exactly when at least one of its direct
+    # inputs may.
+    remaining = search_depth - 1
+    return any(may_depend_on(inp, b, remaining) for inp in producers)
+
+
+def are_nodes_independent(nodes: list[Node]):
+    """
+    Check if all of the given nodes are pairwise-data independent.
+
+    Arguments:
+        nodes: The nodes to check for data dependencies.
+
+    Returns:
+        True if no pair in nodes may have a data dependency (in either
+        direction) according to may_depend_on, False otherwise.
+    """
+    # A single possibly-dependent pair is enough to rule out independence.
+    return not any(
+        may_depend_on(i, j) or may_depend_on(j, i)
+        for i, j in itertools.combinations(nodes, 2)
+    )
+
+
+def merge_matmul(in_mod: torch.nn.Module):
+    """
+    A graph transformation that merges matrix multiplication operations that share the same right-hand
+    side operand into one large matrix multiplication.
+               ____      _________        _________
+      ----    |    |    |         |     M|  A * C  |
+    M| A  |  T| B  | * K|    C    | =    |---------|
+      ---- ,  |    |    |         |     T|  B * C  |
+       K       ----      ---------        ---------
+                K            R                R
+
+    Arguments:
+        in_mod: The module to transform; it is symbolically traced first.
+
+    Returns:
+        The traced GraphModule with the matmuls merged, recompiled and
+        linted.
+    """
+    gm = symbolic_trace(in_mod)
+
+    # Maps each RHS matrix multiply operand to the matmuls of which it is the RHS.
+    rhs_users: dict[Node, list[Node]] = {}
+
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target is not torch.matmul:
+            continue
+
+        # Skip matmuls that were not written with exactly two positional operands.
+        if len(node.args) != 2:
+            continue
+
+        rhs = node.args[1]
+
+        # TODO: Properly handle aliasing caused by get_attr. For now,
+        # use the attribute name as the operand if the node is a
+        # get_attr.
+        rhs = rhs.target if rhs.op == "get_attr" else rhs
+
+        rhs_users.setdefault(rhs, []).append(node)
+
+    for rhs, mms in rhs_users.items():
+        # There must be at least two matmuls for a merge to make sense.
+        if len(mms) < 2:
+            continue
+
+        # All matmuls must not depend on each other directly or indirectly
+        # in order for the merge to be possible.
+        if not are_nodes_independent(mms):
+            continue
+
+        lhs_vals = [mm.args[0] for mm in mms]
+
+        # Merge the matmul.
+        # Collect a list of LHS operands and the single RHS operand.
+        lhs = [gm.graph.get_attr(v) if isinstance(v, str) else v for v in lhs_vals]
+        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
+
+        # Concatenate all the LHS operands.
+        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
+
+        # Multiply the concatenated LHS operands with the one RHS. This will produce
+        # the same results as all the individual matmuls involving rhs in the original graph,
+        # but they will all be concatenated together.
+        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs), {})
+
+        # Split the result of the merged matmul using the shapes of the LHS operands
+        # to ascertain how large each chunk should be.
+        merge_mm_split = gm.graph.call_function(
+            split_result_tensors, (merge_mm, lhs), {}
+        )
+        merge_mm_res = [
+            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
+            for out in range(len(lhs))
+        ]
+
+        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
+        for old, new in zip(mms, merge_mm_res):
+            old.replace_all_uses_with(new)
+            gm.graph.erase_node(old)
+
+        # All of the new nodes created above were inserted at the end, so we need to sort
+        # the nodes topologically to make sure all definitions precede uses.
+        legalize_graph(gm)
+
+    gm.recompile()
+    gm.graph.lint()
+    return gm
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0782ba5affc9cbbe6b55fbba131066a35f331f5a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -0,0 +1,1322 @@
+# mypy: ignore-errors
+import copy
+import itertools
+from collections.abc import Callable
+
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    ApplyBroadcasting,
+    BinConstraintD,
+    CalcConv,
+    CalcMaxPool,
+    CalcProduct,
+    CanReshape,
+    Conj,
+    Constraint,
+    DGreatestUpperBound,
+    Disj,
+    DVar,
+    F,
+    GetItem,
+    GetItemTensor,
+    IndexSelect,
+    Prod,
+    T,
+    TGreatestUpperBound,
+    Transpose,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
+    BinConstraintT,
+    MAX_TENSOR_RANK,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_consistency,
+    op_div,
+    op_eq,
+    op_leq,
+    op_matching,
+    op_mod,
+    op_mul,
+    op_neq,
+    op_precision,
+    op_sub,
+)
+from torch.fx.experimental.migrate_gradual_types.util import (
+    gen_dvar,
+    gen_nat_constraints,
+    gen_tensor_dims,
+)
+from torch.fx.tensor_type import Dyn, TensorType
+
+
+_TRANSFORMATION_RULES: dict[type[Constraint], Callable] = {}  # constraint class -> rule
+
+
+def register_transformation_rule(call_target):
+    """Build a decorator registering its target as the rule for `call_target`."""
+    def decorator(rule_fn):
+        if call_target in _TRANSFORMATION_RULES:
+            message = f"Transformation rule already registered for {call_target}!"
+            raise RuntimeError(message)
+        _TRANSFORMATION_RULES[call_target] = rule_fn
+        return rule_fn
+
+    return decorator
+
+
+def valid_index(index, dims):
+    """
+    Return T() if `dims[index]` exists and F() if the lookup raises IndexError
+    """
+    try:
+        dims[index]
+    except IndexError:
+        return F()
+    return T()
+
+
+@register_transformation_rule(Transpose)
+def transform_transpose(constraint, counter):
+    """
+    Similar to a sequence of two index-selects: when both indices are valid,
+    the output tensor has the two indexed dimensions of the input swapped
+    """
+    i, j = constraint.index1, constraint.index2
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    i_ok = valid_index(i, dims)
+    j_ok = valid_index(j, dims)
+
+    # with an invalid index the dimensions stay unswapped and the
+    # conjunction below contains F()
+    swapped = copy.deepcopy(dims)
+    if i_ok == T() and j_ok == T():
+        swapped[i], swapped[j] = dims[j], dims[i]
+
+    conjuncts = [
+        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+        *gen_nat_constraints(dims),
+        i_ok,
+        j_ok,
+        BinConstraintT(constraint.output, TensorType(swapped), op_eq),
+    ]
+    return Conj(conjuncts), counter
+
+
+@register_transformation_rule(IndexSelect)
+def transform_index_select(constraint, counter):
+    """
+    The constraints consider the given tensor size, checks if the index is valid
+    and if so, generates a constraint for replacing the input dimension
+    with the required dimension
+    Args:
+        constraint: IndexSelect constraint to expand
+        counter: variable tracking
+    Returns: the expanded conjunction and the updated counter
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    is_valid_index = valid_index(constraint.index, dims)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # Copy the input dimensions up front so `new_dims` exists even for an invalid
+    # index; nothing is replaced then and the clause contains False.
+    new_dims = copy.deepcopy(dims)
+    if is_valid_index == T():
+        new_dims[constraint.index] = constraint.dim_replace
+
+    conjuncts = [
+        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+        *nat_constraints,
+        is_valid_index,
+        BinConstraintT(constraint.output, TensorType(new_dims), op_eq),
+    ]
+    return Conj(conjuncts), counter
+
+
+@register_transformation_rule(GetItem)
+def transform_get_item(constraint, counter):
+    """
+    Expand a GetItem constraint on a tensor input. First an equality
+    t = [a1, ..., an]
+    is generated for the input, together with a check that the given index
+    is valid for this particular tensor size. Only if the index is valid, an
+    equality selecting the indexed dimension as the result is added.
+    Note that we already handled the Dyn input case in the previous
+    step.
+    Args:
+        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
+        counter: variable tracking
+    Returns: simplified constraints for GetItem
+
+    """
+    index = constraint.index
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    index_ok = valid_index(index, dims)
+
+    # the item equality is only emitted for a valid index; with an invalid
+    # one the F() stored in `index_ok` already ends up in the conjunction
+    item_constraints = (
+        [BinConstraintD(constraint.res, dims[index], op_eq)]
+        if index_ok == T()
+        else []
+    )
+
+    conjuncts = [
+        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+        *gen_nat_constraints(dims),
+        index_ok,
+        *item_constraints,
+    ]
+    return Conj(conjuncts), counter
+
+
+def valid_index_tensor(index, dims):
+    """
+    Check an index tuple against the dimensions it is applied to.
+    If the slice instances exceed the length of the dimensions
+    then this is a type error so we return False
+    Args:
+        index: iterable of index entries; only slice instances are counted
+        dims: the dimensions being indexed
+    Returns: F() if there are more slices than dimensions, otherwise T()
+    """
+    # only slice entries are counted; every other entry is ignored
+    slice_count = sum(1 for s in index if isinstance(s, slice))
+    return F() if slice_count > len(dims) else T()
+
+
+@register_transformation_rule(GetItemTensor)
+def transform_get_item_tensor(constraint, counter):
+    """
+    When the index is a tuple, then the output will be a tensor
+    TODO: we have to check if this is the case for all HF models
+
+    The cases we are covering here are a tuple with one of:
+     - slice with default argument
+     - None
+
+     None appends 1 to the input tensor dimensions
+     so each occurrence of 'None' increases the rank by 1
+
+     slice with default arguments does not change the rank
+
+    Any other entry raises NotImplementedError. F() is returned when the
+    tuple has more slices than the tensor has dimensions, or when the
+    resulting rank exceeds 4.
+    """
+    index_tuple = constraint.index_tuple
+    assert isinstance(index_tuple, tuple)
+
+    # only None and slices with default arguments are supported as entries
+    for entry in index_tuple:
+        if entry is None or entry == slice(None, None, None):
+            continue
+        raise NotImplementedError("Method not yet implemented")
+
+    # generate a result tensor of the expected size
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # check if the index is valid before building the result: with more
+    # slices than dimensions a None entry could sit past the end of the
+    # resulting dimension list, so such a tuple is rejected right away
+    is_valid_index = valid_index_tensor(index_tuple, dims)
+    if is_valid_index == F():
+        return F(), counter
+
+    # build the result dimensions: every None in the tuple contributes a 1 at
+    # its own position ("slice" does not contribute to the rank), and all
+    # other positions are filled with the input dimensions in order
+    none_c = index_tuple.count(None)
+    remaining_dims = iter(dims)
+    resulting_tensor_dims = [
+        1 if i < len(index_tuple) and index_tuple[i] is None else next(remaining_dims)
+        for i in range(none_c + len(dims))
+    ]
+
+    # check if the resulting tensor is within bounds
+    if len(resulting_tensor_dims) > 4:
+        return F(), counter
+
+    constraints = [
+        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+        BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
+        *nat_constraints,
+        is_valid_index,
+    ]
+    return Conj(constraints), counter
+
+
+@register_transformation_rule(BinConstraintT)
+def generate_binconstraint_t(constraint, counter):
+    """
+    Transform binary constraints for tensors
+    Precision, matching, consistency and op_leq constraints are rewritten;
+    every other constraint, including a precision constraint whose lhs is
+    neither Dyn nor a TensorType, is returned unchanged.
+    Returns: the transformed constraint and the updated counter
+    """
+
+    # precision constraints
+    if constraint.op == op_precision:
+        if constraint.lhs == Dyn:
+            return T(), counter
+        if not isinstance(constraint.lhs, TensorType):
+            # unexpected lhs: keep the constraint so callers still get a pair
+            return constraint, counter
+
+        old_dims = constraint.lhs.__args__
+        if all(d != Dyn for d in old_dims):
+            return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
+
+        # pair every lhs dimension with a fresh dimension variable
+        new_dims = []
+        for _ in range(len(old_dims)):
+            dim, counter = gen_dvar(counter)
+            new_dims.append(dim)
+
+        new_dim_constraints = (
+            [
+                BinConstraintD(old_dim, new_dim, op_precision)
+                for new_dim, old_dim in zip(new_dims, old_dims)
+            ]
+            + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)]
+            + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims]
+        )
+        return Conj(new_dim_constraints), counter
+
+    # matching
+    elif constraint.op == op_matching:
+        assert isinstance(constraint.rhs, TensorType)
+        d1 = constraint.rhs.__args__[0]
+        d2 = constraint.rhs.__args__[1]
+        d3 = constraint.rhs.__args__[2]
+        d4 = constraint.rhs.__args__[3]
+
+        conj = [
+            BinConstraintT(constraint.lhs, Dyn, op_eq),
+            BinConstraintD(d1, Dyn, op_eq),
+            BinConstraintD(d2, Dyn, op_eq),
+            BinConstraintD(d3, Dyn, op_eq),
+            BinConstraintD(d4, Dyn, op_eq),
+        ]
+        return (
+            Disj(
+                [
+                    Conj(conj),
+                    BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq),
+                ]
+            ),
+            counter,
+        )
+
+    elif constraint.op == op_consistency:
+        c_dyn = Disj(
+            [
+                BinConstraintT(constraint.lhs, Dyn, op_eq),
+                BinConstraintT(constraint.rhs, Dyn, op_eq),
+            ]
+        )
+        tensor_disjuncts, counter = gen_consistency_constraints(constraint, counter)
+        c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4 = tensor_disjuncts
+
+        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter
+
+    elif constraint.op == op_leq:
+        assert isinstance(constraint.rhs, int)
+        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
+        for i in range(1, constraint.rhs + 1):
+            dims = []
+            for _ in range(1, i + 1):
+                dim_var, counter = gen_dvar(counter)
+                dims.append(dim_var)
+            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
+        return Disj(disj), counter
+    else:
+        return constraint, counter
+
+
+@register_transformation_rule(BinConstraintD)
+def generate_binconstraint_d(constraint, counter):
+    """
+    Transform binary constraints for dimensions
+    Returns: the transformed constraint and the updated counter; constraints
+    that no rule applies to are returned unchanged
+    """
+    lhs, rhs = constraint.lhs, constraint.rhs
+
+    if constraint.op == op_precision:
+        if isinstance(lhs, int):
+            return BinConstraintD(lhs, rhs, op_eq), counter
+        if lhs == Dyn:
+            return T(), counter
+        # any other lhs falls through to the unchanged return at the bottom
+
+    elif constraint.op == op_consistency:
+        alternatives = [
+            BinConstraintD(lhs, rhs, op_eq),
+            BinConstraintD(rhs, Dyn, op_eq),
+            BinConstraintD(lhs, Dyn, op_eq),
+        ]
+        return Disj(alternatives), counter
+
+    return constraint, counter
+
+
+@register_transformation_rule(Conj)
+def generate_conj(constraint, counter):
+    """
+    Transform every conjunct in turn, threading the counter through
+    """
+    transformed = []
+    for conjunct in constraint.conjucts:
+        result, counter = transform_constraint(conjunct, counter)
+        transformed += [result]
+    return Conj(transformed), counter
+
+
+@register_transformation_rule(Disj)
+def generate_disj(constraint, counter):
+    """
+    Transform every disjunct in turn, threading the counter through
+    """
+    transformed = []
+    for disjunct in constraint.disjuncts:
+        result, counter = transform_constraint(disjunct, counter)
+        transformed += [result]
+    return Disj(transformed), counter
+
+
+@register_transformation_rule(TGreatestUpperBound)
+def generate_gub(constraint, counter):
+    """
+    Transform greatest upper bound for tensors. Results in equality and Greatest
+    Upper Bound on dimensions: a disjunction of the Dyn case built here with the
+    four cases produced by gen_greatest_upper_bound
+    """
+    # Dyn case: one of the operands is Dyn and the result is Dyn as well
+    some_operand_dyn = Disj(
+        [
+            BinConstraintT(constraint.rhs1, Dyn, op_eq),
+            BinConstraintT(constraint.rhs2, Dyn, op_eq),
+        ]
+    )
+    result_dyn = BinConstraintT(constraint.res, Dyn, op_eq)
+    dyn_case = Conj([some_operand_dyn, result_dyn])
+
+    # the helper's first return value is unpacked into exactly four cases
+    [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)
+
+    return Disj([dyn_case, c2, c3, c4, c5]), counter
+
+
+@register_transformation_rule(DGreatestUpperBound)
+def generate_d_gub(constraint, counter):
+    """
+    Transform greatest upper bound for dimensions into equality constraints
+    Args:
+        constraint: DGreatestUpperBound over rhs1 and rhs2 with result res
+        counter: variable tracking (returned unchanged)
+    Returns: a disjunction of three cases
+     - rhs1 is Dyn and res equals rhs2
+     - rhs2 is Dyn and res equals rhs1
+     - rhs2 equals rhs1 and res equals rhs1
+    """
+    rhs1, rhs2, res = constraint.rhs1, constraint.rhs2, constraint.res
+
+    cases = [
+        # rhs1 is Dyn: the result is rhs2
+        Conj([BinConstraintD(rhs1, Dyn, op_eq), BinConstraintD(res, rhs2, op_eq)]),
+        # rhs2 is Dyn: the result is rhs1
+        Conj([BinConstraintD(rhs2, Dyn, op_eq), BinConstraintD(res, rhs1, op_eq)]),
+        # both operands agree: the result is that common value
+        Conj([BinConstraintD(rhs2, rhs1, op_eq), BinConstraintD(res, rhs1, op_eq)]),
+    ]
+
+    return Disj(cases), counter
+
+
+@register_transformation_rule(CalcConv)
+def generate_calc_conv(constraint, counter):
+    """
+    Transform convolution constraints
+    Args:
+        constraint: CalcConv to expand
+        counter: variable tracking
+    Returns: conjunction describing the 4 output dimensions, and the counter
+    """
+    d, counter = gen_tensor_dims(4, counter)
+
+    # the convolution result is a tensor of size 4
+    result_type = TensorType([d[0], d[1], d[2], d[3]])
+    is_rank_4 = BinConstraintT(constraint.conv_result, result_type, op_eq)
+
+    # the second dimension of the output is equal to the output channels
+    channel_eq = BinConstraintD(d[1], constraint.c_out, op_eq)
+    channel_static = BinConstraintD(d[1], Dyn, op_neq)
+    channels = Conj([channel_eq, channel_static])
+
+    # the input corresponds to the output in the first dimension of the convolution
+    first_dim = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
+
+    # the last two dimensions are handled by calc_last_two_dims
+    c4, c5 = calc_last_two_dims(constraint, d)
+
+    # one op_leq constraint against 0 for each of the four output dimensions
+    non_negative = Conj([BinConstraintD(0, d[i], op_leq) for i in range(4)])
+
+    conjuncts = [is_rank_4, channels, first_dim, c4, c5, non_negative]
+    return Conj(conjuncts), counter
+
+
+@register_transformation_rule(CalcMaxPool)
+def generate_calc_maxpool(constraint, counter):
+    """
+    Transform maxpool constraints
+    Args: constraint is the CalcMaxPool to expand, counter tracks variables
+    Returns: conjunction describing the 4 output dimensions, and the counter
+    """
+    d, counter = gen_tensor_dims(4, counter)
+
+    # the maxpool result is a tensor of size 4
+    result_type = TensorType([d[0], d[1], d[2], d[3]])
+    is_rank_4 = BinConstraintT(constraint.maxpool_result, result_type, op_eq)
+
+    # the input corresponds to the output in the first and second dimension of maxpool
+    matching = constraint.matching_constraint
+    second_dim = BinConstraintD(matching[1], d[1], op_eq)
+    first_dim = BinConstraintD(matching[0], d[0], op_eq)
+
+    # the last two dimensions are handled by calc_last_two_dims
+    c4, c5 = calc_last_two_dims(constraint, d)
+
+    # one op_leq constraint against 0 for each of the four output dimensions
+    non_negative = Conj([BinConstraintD(0, d[i], op_leq) for i in range(4)])
+
+    conjuncts = [is_rank_4, second_dim, first_dim, c4, c5, non_negative]
+    return Conj(conjuncts), counter
+
+
+@register_transformation_rule(CalcProduct)
+def generate_calc_product(constraint, counter):
+    """
+    Transform flatten constraints
+
+    The dimensions in the range [start, end) of `dims_to_flatten` are replaced
+    by a single dimension. One disjunct is generated for every possibility
+    returned by generate_all_int_dyn_dim_possibilities for that range:
+     - if all constraints of the possibility use op_neq, the new dimension is
+       a fresh variable that is not Dyn and equals the product of the range
+     - otherwise the new dimension is Dyn
+    A possibility whose flattened tensor would have more than 4 dimensions
+    contributes F() instead.
+
+    Args:
+        constraint: CalcProduct describing which dimensions get flattened
+        counter: variable tracking
+
+    Returns: the disjunction of all possibilities conjoined with the check
+        0 <= start < end <= len(dims_to_flatten), and the updated counter
+    """
+    start = constraint.start
+    end = constraint.end
+    dims = constraint.dims_to_flatten
+    flattened = constraint.flattened
+    n = len(dims)
+
+    # this will be evaluated right here
+    in_bounds = 0 <= start < end <= n
+    c_boundary = T() if in_bounds else F()
+
+    # dimensions before, inside and after the flattened range
+    lhs = dims[0:start]
+    mid = dims[start:end]
+    rhs = dims[end:]
+
+    # one entry per possibility; they are combined into a disjunction below
+    all_constraints = []
+
+    for p in generate_all_int_dyn_dim_possibilities(mid):
+        p = list(p)
+        if all(c.op == op_neq for c in p):
+            # no dynamic variable: flatten into a fresh non-Dyn dimension
+            new_var, counter = gen_dvar(counter)
+            mid_var = [new_var]
+            extra = [
+                Conj(
+                    [
+                        BinConstraintD(new_var, Prod(mid), op_eq),
+                        BinConstraintD(new_var, Dyn, op_neq),
+                    ]
+                )
+            ]
+        else:
+            # this tells us there is a dynamic variable
+            mid_var = [Dyn]
+            extra = []
+
+        # the flattened range collapses into the single dimension in mid_var
+        flattened_dims = lhs + mid_var + rhs
+        # more than 4 dimensions: this possibility becomes F()
+        if len(flattened_dims) > 4:
+            all_constraints.append(F())
+        else:
+            flattened_eq = BinConstraintT(
+                flattened, TensorType(flattened_dims), op_eq
+            )
+            all_constraints.append(Conj([flattened_eq, *extra, *p]))
+
+    return Conj([Disj(all_constraints), c_boundary]), counter
+
+
+@register_transformation_rule(CanReshape)
+def generate_reshape(constraint, counter):
+    """
+    Transform reshape constraints
+
+    The source is either Dyn or a tensor of rank 1 to 4 over fresh dimensions
+    d1..d4. Per rank, the dimensions are related to the target as follows:
+     - fully static target: for rank 1, d1 is Dyn or equals the product of
+       the target; higher ranks use gen_all_reshape_possibilities
+     - target containing Dyn: some source dimension is Dyn, or none is and
+       is_dim_div_by_target is applied to the non-Dyn target dimensions and
+       the product of the source dimension(s)
+    Args:
+        constraint: CanReshape with a source `src` and a `target` type
+        counter: variable tracking
+    Returns: the transformed constraint and the updated counter
+    """
+    d, counter = gen_tensor_dims(4, counter)
+
+    d1 = d[0]
+    d2 = d[1]
+    d3 = d[2]
+    d4 = d[3]
+
+    target = constraint.target.__args__
+
+    is_fully_static = all(dim != Dyn for dim in target)
+
+    # dynamic tensor
+    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
+    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
+    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
+    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
+    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)
+
+    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
+    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)
+
+    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
+    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)
+
+    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
+    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
+
+    # these two must refer to d4 (not d3) so that the rank-4 case really
+    # constrains its last dimension
+    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
+    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)
+
+    nat_d1 = BinConstraintD(0, d1, op_leq)
+    nat_d2 = BinConstraintD(0, d2, op_leq)
+    nat_d3 = BinConstraintD(0, d3, op_leq)
+    nat_d4 = BinConstraintD(0, d4, op_leq)
+
+    if is_fully_static:
+        # size 1 tensor
+        c3_tensor1 = Disj(
+            [d1_eq_dyn, (Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]))]
+        )
+        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
+
+        # size 2 tensor
+        all_tensor_2 = Conj(
+            [c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]
+        )
+
+        # size 3 tensor
+        all_tensor_3 = Conj(
+            [c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]
+        )
+
+        # size 4 tensor
+        all_tensor_4 = Conj(
+            [c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]
+        )
+
+    # then there must be exactly one occurrence of dyn
+    else:
+        new_target = [n for n in target if n != Dyn]
+
+        # tensor 1
+        c3_tensor1 = Disj(
+            [d1_eq_dyn, (Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]))]
+        )
+        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
+
+        # tensor 2
+        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
+        c22 = Conj(
+            [d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]
+        )
+        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])
+
+        # tensor 3
+        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
+        c32 = Conj(
+            [
+                d1_neq_dyn,
+                d2_neq_dyn,
+                d3_neq_dyn,
+                is_dim_div_by_target(new_target, Prod([d1, d2, d3])),
+            ]
+        )
+        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])
+
+        # tensor 4
+        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
+        c42 = Conj(
+            [
+                d1_neq_dyn,
+                d2_neq_dyn,
+                d3_neq_dyn,
+                d4_neq_dyn,
+                is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4])),
+            ]
+        )
+        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])
+
+    # both branches combine their per-rank cases in the same way
+    return (
+        Conj(
+            [
+                Disj(
+                    [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]
+                ),
+                nat_d1,
+                nat_d2,
+                nat_d3,
+                nat_d4,
+            ]
+        ),
+        counter,
+    )
+
+
+@register_transformation_rule(ApplyBroadcasting)
+def generate_broadcasting(constraint, counter):
+    """
+    Transform broadcasting constraints
+
+    The result is a disjunction of
+     - input1 is Dyn, and both inputs equal their results
+     - input2 is Dyn, and both inputs equal their results
+     - the constraints produced by gen_broadcasting_constraints for tensors
+       of size 1 (only the first one), 2, 3 and 4
+    conjoined with the `nat_dims` constraints returned for every size.
+
+    Args:
+        constraint: ApplyBroadcasting over input1/input2 with res1/res2
+        counter: variable tracking
+
+    Returns: the transformed constraint and the updated counter
+    """
+    e11, e12 = constraint.res1, constraint.res2
+    e1, e2 = constraint.input1, constraint.input2
+
+    # Introduce dimensions
+    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
+    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)
+
+    # dyn possibility
+    disjuncts = [
+        Conj([BinConstraintT(e1, Dyn, op_eq), e1_equal_e11, e2_equal_e12]),
+        Conj([BinConstraintT(e2, Dyn, op_eq), e1_equal_e11, e2_equal_e12]),
+    ]
+
+    # tensor possibility
+    # generate dimensions to create tensors of size 1
+    size_1, _, _, nat_dims_1, counter = gen_broadcasting_constraints(
+        e1, e2, e11, e12, 1, counter
+    )
+    disjuncts.append(size_1)
+
+    # generate dimensions to create tensors of size 2
+    (
+        size_2_no_padding,
+        size_2_padding_arg1,
+        size_2_padding_arg2,
+        nat_dims_2,
+        counter,
+    ) = gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)
+    disjuncts += [size_2_no_padding, size_2_padding_arg1, size_2_padding_arg2]
+
+    # generate dimensions to create tensors of size 3
+    (
+        size_3_no_padding,
+        size_3_padding_arg1,
+        size_3_padding_arg2,
+        nat_dims_3,
+        counter,
+    ) = gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)
+    disjuncts += [size_3_no_padding, size_3_padding_arg1, size_3_padding_arg2]
+
+    # generate dimensions to create tensors of size 4
+    (
+        size_4_no_padding,
+        size_4_padding_arg1,
+        size_4_padding_arg2,
+        nat_dims_4,
+        counter,
+    ) = gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)
+    disjuncts += [size_4_no_padding, size_4_padding_arg1, size_4_padding_arg2]
+
+    # every size contributes its own nat_dims constraints
+    return (
+        Conj([Disj(disjuncts), *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]),
+        counter,
+    )
+
+
+def transform_constraint(constraint: Constraint, counter: int):
+    """
+    Transforms a constraint into a simpler constraint.
+    Ex: precision and consistency are transformed to equality
+    Args:
+        constraint: constraint to be transformed
+        counter: for variable tracking
+
+    Returns: Constraint
+
+    """
+    if type(constraint) in _TRANSFORMATION_RULES:
+        return _TRANSFORMATION_RULES[type(constraint)](constraint, counter)
+
+    else:
+        return constraint, counter
+
+
+def calc_last_two_dims(constraint, d: list[DVar]):
+    """
+    Generates constraints for the last two dimensions of a convolution or a maxpool output
+    Args:
+        constraint: CalcConv or CalcMaxPool
+        d: The list of output dimensions
+
+    Returns: Constraints for calculating the last two dimensions of the output
+
+    """
+
+    assert isinstance(constraint, (CalcConv, CalcMaxPool))
+
+    b3 = constraint.matching_constraint[2]
+    b4 = constraint.matching_constraint[3]
+
+    b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
+    b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])
+
+    d3_not_dyn = Conj(
+        [BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]
+    )
+    d4_not_dyn = Conj(
+        [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]
+    )
+
+    # transform parameters into tuples in case they are not already
+    padding = (
+        (constraint.padding, constraint.padding)
+        if isinstance(constraint.padding, int)
+        else constraint.padding
+    )
+    kernel = (
+        (constraint.kernel, constraint.kernel)
+        if isinstance(constraint.kernel, int)
+        else constraint.kernel
+    )
+    stride = (
+        (constraint.stride, constraint.stride)
+        if isinstance(constraint.stride, int)
+        else constraint.stride
+    )
+    dilation = (
+        (constraint.dilation, constraint.dilation)
+        if isinstance(constraint.dilation, int)
+        else constraint.dilation
+    )
+
+    f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
+    f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
+    f3 = BinConstraintD(
+        BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div
+    )
+    f4 = BinConstraintD(f3, 1, op_add)
+
+    c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])
+
+    f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
+    f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
+    f33 = BinConstraintD(
+        BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div
+    )
+    f44 = BinConstraintD(f33, 1, op_add)
+
+    c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])
+
+    return c4, c5
+
+
+def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]):
+    """
+    Generate all possibilities of being equal or not equal to dyn for my_list
+    Args:
+        my_list: List of tensor dimensions
+
+    Returns: A list of a list of constraints. Each list of constraints corresponds to
+    one possibility about the values of the dimension variables
+    """
+    # generate all possibilities of being equal or not equal to dyn for my_list
+    eq_possibilities = [
+        BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))
+    ]
+    neq_possibilities = [
+        BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))
+    ]
+
+    d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)]
+    all_possibilities = list(itertools.product(*d_possibilities))
+    return all_possibilities
+
+
+def is_target_div_by_dim(target: list[int], dim: list[DVar]):
+    """
+    Generate constraints to check if the target dimensions are divisible by the input dimensions
+    Args:
+        target: Target dimensions
+        dim: Input dimensions
+
+    Returns: Constraints to check divisibility
+
+    """
+    return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
+
+
+def is_dim_div_by_target(target: list[int], dim: list[DVar]):
+    """
+    Generate constraints to check if the input dimensions is divisible by the target dimensions
+    Args:
+        target: Target dimensions
+        dim:  Input dimensions
+
+    Returns: Constraints to check divisibility
+
+    """
+    return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq)
+
+
+def gen_all_reshape_possibilities(list_of_dims, target):
+    """
+    Consider all possibilities what the input dimensions could be (number or dynamic)
+    Then generate the appropriate constraints using multiplication or mod depending on the possibility
+    The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
+    for the input. Target is fixed because at most one dimension could be dyn.
+    We have different cases for this.
+
+    Args:
+        list_of_dims: The input list of dimensions
+        target: The tensor we want to reshape to
+
+    Returns: A disjunction of transformed reshape constraints
+
+    """
+    all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)
+
+    all_constraints = []
+
+    for p in all_possibilities:
+        to_multiply = []
+
+        p = list(p)
+
+        for constraint in p:
+            assert isinstance(constraint, BinConstraintD)
+            if constraint.op == op_neq:
+                to_multiply.append(constraint.lhs)
+
+        if not to_multiply:
+            all_constraints.append(Conj(p))
+
+        elif len(to_multiply) < len(list_of_dims):
+            all_constraints.append(
+                Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])
+            )
+        else:
+            all_constraints.append(
+                Conj(p + [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)])
+            )
+
+    return Disj(all_constraints)
+
+
+def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
+    """
+    Apply broadcasting to the 'index' dimension of tensor_input1.
+    Args:
+        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
+        tensor_input2: represents the second input
+        res1: broadcasted result 1
+        res2: broadcasted result 2
+        index: the index to broadcast
+        padding: If padding was used, then tensor_input1[index] does not exist
+
+    Returns:
+
+    """
+    if tensor_input1[index] is None:
+        assert padding
+
+    if not padding:
+        # then the inputs are the same length so they all have dimensions at "index"
+        return Conj(
+            [
+                BinConstraintD(tensor_input1[index], 1, op_eq),
+                BinConstraintD(res1[index], res2[index], op_eq),
+                BinConstraintD(res2[index], tensor_input2[index], op_eq),
+            ]
+        )
+
+    else:
+        # we don't set the input dimension to 1, since it doesn't exist.
+        return Conj(
+            [
+                BinConstraintD(res1[index], res2[index], op_eq),
+                BinConstraintD(res2[index], tensor_input2[index], op_eq),
+            ]
+        )
+
+
+def apply_padding(
+    e1_var: TVar,
+    e11: BinConstraintT,
+    e2: BinConstraintT,
+    e12: BinConstraintT,
+    d2: list[DVar],
+    d11: list[DVar],
+    d12: list[DVar],
+    counter: int,
+):
+    """
+    We are considering the possibility where one input has less dimensions than
+    another input, so we apply padding to the broadcasted results
+
+    Args:
+        e1_var: Variable representing the first input where padding will be
+        e11: constraint of the form e11 = Tensortype[d1, ..., dn]
+        e2:  constraint of the form e2 = Tensortype[d1, ..., dn]
+        e12: constraint of the form e11 = Tensortype[d1, ..., dn]
+        d2: Tensor variables for the second input
+        d11: Tensor variables for the broadcasted first input
+        d12: Tensor variables for the broadcasted second input
+        counter: variable tracking
+
+    Returns: A new constraint whose goal is to apply padding to the broadcasted result
+
+    """
+
+    res = []
+
+    # pad the shorter input with None so we can pass it to the broadcasting helper function
+    for i in range(1, len(d2)):
+        d1, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
+
+        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)
+
+        simulate_padding = [None] * (len(d2) - i)
+
+        assert len(simulate_padding + d1) == len(d2)
+
+        # for every padding size, we also consider broadcasting
+        broadcast_padding = [
+            broadcast_dim(simulate_padding, d2, d11, d12, j, True)
+            for j in range(len(d2) - i)
+        ]
+
+        # we consider the possibilities for broadcasting for every dimension. Since we already
+        # padded d1, we do not consider it while broadcasting
+        all_broadcasting_possibilities = (
+            generate_all_broadcasting_possibilities_no_padding(
+                d1, d2[(len(d2) - i) :], d11[(len(d2) - i) :], d12[(len(d2) - i) :]
+            )
+        )
+        # combine all constraints into a conjunction
+        c = Conj(
+            [
+                e1,
+                e11,
+                e2,
+                e12,
+                *broadcast_padding,
+                all_broadcasting_possibilities,
+                *nat_constraints,
+            ]
+        )
+        res.append(c)
+
+    return Disj(res), counter
+
+
+def no_broadcast_dim_with_index(
+    d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int
+):
+    """
+    Args:
+        d1: input 1
+        d2: input 2
+        d3: simulated broadcasting for input 1
+        d4: simulated broadcasting for input 2
+        i: the rank of the resulting tensor addition
+
+    Returns: Constraints for when no broadcasting occurs
+    """
+    return Conj(
+        [
+            Disj(
+                [
+                    Conj(
+                        [
+                            BinConstraintD(d1[i], 1, op_eq),
+                            BinConstraintD(d2[i], 1, op_eq),
+                        ]
+                    ),
+                    Conj(
+                        [
+                            BinConstraintD(d1[i], 1, op_neq),
+                            BinConstraintD(d2[i], 1, op_neq),
+                        ]
+                    ),
+                ]
+            ),
+            BinConstraintD(d1[i], d3[i], op_eq),
+            BinConstraintD(d2[i], d4[i], op_eq),
+        ]
+    )
+
+
+def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
+    """
+    Generate lists of DVar to represent tensor dimensions
+    Args:
+        num_tensors: the required number of tensors
+        dim_size: the number of dimensions for each tensor
+        counter: variable tracking
+
+    Returns: A list of a list of tensor dimensions
+
+    """
+    res = []
+
+    for _ in range(num_tensors):
+        dims, counter = gen_tensor_dims(dim_size, counter)
+        res.append(dims)
+
+    return res, counter
+
+
+def create_equality_constraints_for_broadcasting(
+    e1: TVar,
+    e2: TVar,
+    e11: TVar,
+    e12: TVar,
+    d1: list[DVar],
+    d2: list[DVar],
+    d11: list[DVar],
+    d12: list[DVar],
+):
+    """
+    Create equality constraints for when no broadcasting occurs
+    Args:
+        e1: Input 1
+        e2: Input 2
+        e11: Broadcasted input 1
+        e12: Broadcasted input 2
+        d1: Variables that store dimensions for e1
+        d2: Variables that store dimensions for e2
+        d11: Variables that store dimensions for e11
+        d12: Variables that store dimensions for e22
+
+    Returns: Four equality constraints
+
+    """
+
+    e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq)
+    e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq)
+    e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq)
+    e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq)
+    return [e1_tensor, e11_tensor, e2_tensor, e12_tensor]
+
+
+def gen_consistency_constraints(constraint: Constraint, counter: int):
+    """
+    Args:
+        constraint: Consistency constraint on tensors
+        counter: for variable tracking
+
+    Returns: Equality and consistency constraints on dimensions
+
+    """
+
+    all_constraints = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
+
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
+                BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq),
+            ]
+            + [
+                BinConstraintD(d1, d2, op_consistency)
+                for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)
+            ]
+            + nat_constraints
+        )
+
+        all_constraints.append(c_tensor_i)
+
+    return all_constraints, counter
+
+
+def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
+    """
+    Args:
+        constraint: Greatest upper bound on tensors
+        counter: variable tracking
+
+    Returns: A set of equality constraints and DGreatestUpperBound constraints
+
+    """
+
+    all_constraints = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        c = []
+        dims1, counter = gen_tensor_dims(i, counter)
+        c1tensor = TensorType(dims1)
+
+        dims2, counter = gen_tensor_dims(i, counter)
+        c2tensor = TensorType(dims2)
+
+        dims3, counter = gen_tensor_dims(i, counter)
+        c3tensor = TensorType(dims3)
+
+        c += [
+            BinConstraintT(constraint.rhs1, c1tensor, op_eq),
+            BinConstraintT(constraint.rhs2, c2tensor, op_eq),
+            BinConstraintT(constraint.res, c3tensor, op_eq),
+        ] + gen_nat_constraints(dims1 + dims2 + dims3)
+
+        assert (
+            len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
+        )
+        for i in range(len(c3tensor.__args__)):
+            c.append(
+                DGreatestUpperBound(
+                    c3tensor.__args__[i], c1tensor.__args__[i], c2tensor.__args__[i]
+                )
+            )
+
+        all_constraints.append(Conj(c))
+    return all_constraints, counter
+
+
+def generate_all_broadcasting_possibilities_no_padding(
+    d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar]
+):
+    """
+    Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
+    We look at all combinations for all dimensions in d1 and d2
+    Args:
+        d1: input1 dimensions
+        d2: input2 dimensions
+        d11: broadcasted input1 dimensions
+        d12: broadcasted input2 dimensions
+
+    Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions
+
+    """
+
+    size = len(d1)
+
+    res2 = []
+
+    for i in range(size):
+        t1 = broadcast_dim(d1, d2, d11, d12, i)
+        t2 = broadcast_dim(d2, d1, d12, d11, i)
+        t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)
+
+        res2.append(Disj([t1, t2, t3]))
+
+    return Conj(res2)
+
+
+def gen_broadcasting_constraints(
+    e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int
+):
+    """
+    Simulates broadcasting on e1 and e2 and returns the results
+    respectively in e11 and e12. Because of gradual types,
+    e1 and e2 may not be equal. Similarly, e11 and e12 may not
+    be equal. e11 and e12 should be guaranteed to be consistent
+    as they represent the shapes of the tensors to be added after
+    broadcasting.
+    Args:
+        e1: TVar representing the type of input 1
+        e2: TVar representing the type of input 2
+        e11: TVar representing the representing broadcasted input 1
+        e12: TVar representing the representing broadcasted input 2
+        i: The rank of the resulting type of addition
+        counter: for variable tracking
+
+    Returns: Simplified broadcasting constraints
+
+    """
+    dims, counter = gen_lists_of_dims(4, i, counter)
+    [d1, d2, d3, d4] = dims
+    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
+
+    initialize_tensors_constraints = create_equality_constraints_for_broadcasting(
+        e1, e2, e11, e12, d1, d2, d3, d4
+    )
+
+    [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints
+
+    # without padding, broadcast all possibilities for tensors of size i
+    final_tensor_constraint_no_padding = Conj(
+        [
+            *initialize_tensors_constraints,
+            generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4),
+        ]
+    )
+
+    # with padding, broadcast all possibilities for tensors of size i
+    final_tensor_constraint_padding_arg1, counter = apply_padding(
+        e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter
+    )
+
+    final_tensor_constraint_padding_arg2, counter = apply_padding(
+        e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter
+    )
+
+    return (
+        final_tensor_constraint_no_padding,
+        final_tensor_constraint_padding_arg1,
+        final_tensor_constraint_padding_arg2,
+        nat_dims_i,
+        counter,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f9f33965e07551c651fa560a80c5e263dd5b85
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
@@ -0,0 +1,446 @@
+# mypy: allow-untyped-defs
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    BinConstraintD,
+    BinConstraintT,
+    BVar,
+    Conj,
+    Disj,
+    DVar,
+    F,
+    is_algebraic_expression,
+    is_bool_expr,
+    is_dim,
+    Prod,
+    T,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
+    ConstraintGenerator,
+)
+from torch.fx.experimental.migrate_gradual_types.constraint_transformation import (
+    transform_constraint,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_div,
+    op_eq,
+    op_gt,
+    op_leq,
+    op_lt,
+    op_mod,
+    op_mul,
+    op_neq,
+    op_sub,
+)
+from torch.fx.tensor_type import Dyn, TensorType
+
+
+try:
+    import z3  # type: ignore[import]
+
+    from torch.fx.experimental.migrate_gradual_types.z3_types import (
+        D,
+        tensor_type,
+        z3_dyn,
+    )
+
+    HAS_Z3 = True
+
+    def transform_to_z3(constraint, counter, dimension_dict):
+        """
+        Recursively transforms a constraint to a z3 format
+        Args:
+            constraint: Conj, Disj, T, F, BinConstraintT or BinConstraintD
+            counter: variable tracking
+            dimension_dict: passed through to transform_dimension, which maps
+                dimension variables to the counter value assigned to them
+
+        Returns: the transformed constraint and the current counter
+
+        Raises NotImplementedError for any other constraint and for operations
+        that are not handled below
+
+        """
+        # a conjunction becomes z3.And over its transformed conjuncts
+        if isinstance(constraint, Conj):
+            conjuncts = []
+            for c in constraint.conjucts:
+                new_c, counter = transform_to_z3(c, counter, dimension_dict)
+                conjuncts.append(new_c)
+            return z3.And(conjuncts), counter
+
+        # a disjunction becomes z3.Or over its transformed disjuncts
+        elif isinstance(constraint, Disj):
+            disjuncts = []
+            for c in constraint.disjuncts:
+                new_c, counter = transform_to_z3(c, counter, dimension_dict)
+                disjuncts.append(new_c)
+            return z3.Or(disjuncts), counter
+
+        elif isinstance(constraint, T):
+            return True, counter
+
+        elif isinstance(constraint, F):
+            return False, counter
+
+        # constraints on tensors: only equality is supported
+        elif isinstance(constraint, BinConstraintT):
+            if constraint.op == op_eq:
+                lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
+                return (lhs == rhs), counter
+
+            else:
+                raise NotImplementedError("Method not yet implemented")
+
+        # constraints on dimensions, booleans and algebraic expressions
+        elif isinstance(constraint, BinConstraintD):
+            if constraint.op == op_eq:
+                # a boolean variable defined by a boolean expression
+                if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
+                    transformed_rhs, counter = transform_to_z3(
+                        constraint.rhs, counter, dimension_dict
+                    )
+                    transformed_lhs = z3.Bool(constraint.lhs.c)
+                    return transformed_lhs == transformed_rhs, counter
+
+                elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
+                    # with dimension transformations we consider the encoding
+                    lhs, counter = transform_dimension(
+                        constraint.lhs, counter, dimension_dict
+                    )
+                    rhs, counter = transform_dimension(
+                        constraint.rhs, counter, dimension_dict
+                    )
+                    return lhs == rhs, counter
+
+                else:
+                    # then we have an algebraic expression which means that we disregard the
+                    # first element of the encoding
+                    lhs, counter = transform_algebraic_expression(
+                        constraint.lhs, counter, dimension_dict
+                    )
+                    rhs, counter = transform_algebraic_expression(
+                        constraint.rhs, counter, dimension_dict
+                    )
+                    return lhs == rhs, counter
+
+            # The assumption here is that the LHS and RHS must be dimensions
+            elif constraint.op == op_neq:
+                assert is_dim(constraint.lhs)
+                assert is_dim(constraint.rhs)
+                lhs, counter = transform_dimension(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_dimension(
+                    constraint.rhs, counter, dimension_dict
+                )
+                # arg(0) is the first element of the encoding: transform_dimension
+                # uses 0 for Dyn and 1 for a number. arg(1) is the second element
+                if constraint.rhs == Dyn or constraint.lhs == Dyn:
+                    # not equal to Dyn: the first element of the other side is 1
+                    if constraint.rhs == Dyn:
+                        return lhs.arg(0) == 1, counter
+                    elif constraint.lhs == Dyn:
+                        return rhs.arg(0) == 1, counter
+
+                # if one of the instances is a number
+                elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
+                    # the other side has 0 as its first element, or it has 1
+                    # and the second elements differ
+                    if isinstance(constraint.lhs, int):
+                        return (
+                            z3.Or(
+                                [
+                                    rhs.arg(0) == 0,
+                                    z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
+                                ]
+                            ),
+                            counter,
+                        )
+
+                    elif isinstance(constraint.rhs, int):
+                        return (
+                            z3.Or(
+                                [
+                                    lhs.arg(0) == 0,
+                                    z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
+                                ]
+                            ),
+                            counter,
+                        )
+
+                else:
+                    # neither side is Dyn or a number: exactly one side has 0 as
+                    # its first element, or neither does and the second
+                    # elements differ
+                    return (
+                        z3.Or(
+                            [
+                                z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
+                                z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
+                                z3.And(
+                                    [
+                                        lhs.arg(0) != 0,
+                                        rhs.arg(0) != 0,
+                                        lhs.arg(1) != rhs.arg(1),
+                                    ]
+                                ),
+                            ]
+                        ),
+                        counter,
+                    )
+
+            elif constraint.op == op_leq:
+                # if the dimensions are not dyn, this will come into effect
+                # there would have been another constraint specifying if a given dimension
+                # is dyn or not
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_algebraic_expression(
+                    constraint.rhs, counter, dimension_dict
+                )
+                return lhs <= rhs, counter
+
+            # op_gt and op_lt compare the second elements of the encodings,
+            # in the same way as op_leq above
+            elif constraint.op == op_gt:
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_algebraic_expression(
+                    constraint.rhs, counter, dimension_dict
+                )
+                return lhs > rhs, counter
+
+            elif constraint.op == op_lt:
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_algebraic_expression(
+                    constraint.rhs, counter, dimension_dict
+                )
+                return lhs < rhs, counter
+
+            else:
+                raise NotImplementedError("operation not yet implemented")
+
+        else:
+            raise NotImplementedError("Operation not yet implemented")
+
+    def transform_var(tensor, counter, dimension_dict):
+        """
+        Transforms tensor variables to a format understood by z3
+        Args:
+            tensor: Tensor variable or a tensor type potentially with variable dimensions
+        Returns: Transformed variable to a z3 format
+
+        """
+        if isinstance(tensor, TensorType):
+            res = []
+            for t in tensor.__args__:
+                transformed, counter = transform_dimension(t, counter, dimension_dict)
+                res.append(transformed)
+
+            assert len(res) <= 4
+            if len(tensor.__args__) == 1:
+                return tensor_type.tensor1(res[0]), counter
+            elif len(tensor.__args__) == 2:
+                return tensor_type.tensor2(res[0], res[1]), counter
+            elif len(tensor.__args__) == 3:
+                return tensor_type.tensor3(res[0], res[1], res[2]), counter
+            elif len(tensor.__args__) == 4:
+                return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
+
+        elif tensor == Dyn:
+            return z3_dyn, counter
+
+        elif isinstance(tensor, TVar):
+            return z3.Const(tensor.tvar, tensor_type), counter
+
+    def transform_dimension(dimension, counter, dimension_dict):
+        """
+        Takes a dimension variable or a number and transforms it to a tuple
+        according to our scheme
+        Args:
+            dimension: The dimension to be transformed
+            counter: variable tracking
+
+        Returns:  tuple and the current counter
+
+        """
+        if dimension == Dyn:
+            counter += 1
+            return D(0, z3.Int(counter)), counter
+        elif isinstance(dimension, int):
+            return D(1, dimension), counter
+        elif isinstance(dimension, DVar):
+            if dimension.c in dimension_dict:
+                return (
+                    D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)),
+                    counter,
+                )
+            else:
+                counter += 1
+                dimension_dict[dimension.c] = counter
+                return D(z3.Int(counter), z3.Int(dimension.c)), counter
+
+    def transform_algebraic_expression(expr, counter, dimension_dict):
+        """
+        Transforms an algebraic expression to z3 format
+        Args:
+            expr: An expression is either a dimension variable or an algebraic-expression
+
+
+        Returns: the transformed expression
+
+        """
+        assert is_algebraic_expression(expr) or is_dim(expr)
+
+        if is_dim(expr):
+            transformed, counter = transform_dimension(expr, counter, dimension_dict)
+            return transformed.arg(1), counter
+
+        elif isinstance(expr, Prod):
+            dims = []
+            for dim in expr.products:
+                assert is_dim(dim)
+                d, counter = transform_dimension(dim, counter, dimension_dict)
+                dims.append(d.arg(1))
+            return z3.Product(dims), counter
+
+        elif is_algebraic_expression(expr):
+            lhs, counter = transform_algebraic_expression(
+                expr.lhs, counter, dimension_dict
+            )
+            rhs, counter = transform_algebraic_expression(
+                expr.rhs, counter, dimension_dict
+            )
+
+            if expr.op == op_sub:
+                c = lhs - rhs
+
+            elif expr.op == op_add:
+                c = lhs + rhs
+
+            elif expr.op == op_div:
+                c = lhs / rhs
+
+            elif expr.op == op_mul:
+                c = lhs * rhs
+
+            elif expr.op == op_mod:
+                c = lhs % rhs
+
+            else:
+                raise NotImplementedError("operation not yet implemented")
+
+            return c, counter
+
+        else:
+            raise RuntimeError
+
+    def transform_all_constraints(traced, counter=0):
+        """
+        Given a trace, generates constraints and transforms them to z3 format
+
+        """
+        dimension_dict = {}  # type: ignore[var-annotated]
+
+        generator = ConstraintGenerator(traced)
+        new_constraints, counter = generator.generate_constraints(counter)
+
+        # print(new_constraints.conjucts[0])
+        # print(*new_constraints.conjucts, sep='\n')
+
+        # transform precision, matching, consistency till obtaining a fixed point
+        new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
+        # print(new_constraints)
+        # print(new_constraints.conjucts)
+        # new_constraints.conjucts = new_constraints.conjucts[:-1]
+        # print(*new_constraints.conjucts, sep='\n')
+
+        transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
+        # print(transformed)
+        return transformed
+
+    def iterate_till_fixed_point(constraints, counter):
+        """
+        Re-apply transform_constraint until its output compares equal to its input
+        """
+        old_c = None  # sentinel so the first comparison differs (assuming constraints is not None)
+        while old_c != constraints:
+            old_c = constraints
+            constraints, counter = transform_constraint(constraints, counter)
+        return constraints, counter
+
+    def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
+        """
+        Takes a node and a graph and generates two sets of constraints.
+        One set constrains the node's condition to hold and the other
+        constrains the negation of the node's condition to hold.
+        Args:
+            tracer_root: the root for getting the module instances
+            graph: the graph so far in the tracing process
+            node: node that represents a conditional (not referenced in this body)
+            counter: variable tracking
+
+        Returns: Two sets of constraints. One with a conjunction with
+        the conditional constraint and the other with a conjunction with
+        its negation.
+
+        """
+        dimension_dict = {}  # type: ignore[var-annotated]
+
+        generator = ConstraintGenerator(tracer_root, graph)
+        new_constraints, counter = generator.generate_constraints(counter)
+
+        condition_constraint = new_constraints.conjucts[-1]
+
+        # we know the constraint is a conjunction where the last constraint is about the conditional
+        # so remove the last constraint
+        new_constraints.conjucts = new_constraints.conjucts[:-1]
+
+        # transform precision, matching, consistency till obtaining a fixed point
+        new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
+
+        # since the function returns a list of one element, we get the first element
+        # we are only interested in the RHS in this case because the LHS just stores
+        # the result
+
+        # we make sure the constraint is of the form:
+        # c = b where b is a boolean expression
+        # and we consider b (constraint.rhs) for transformation
+        assert isinstance(condition_constraint.lhs, BVar)
+        assert is_bool_expr(condition_constraint.rhs)
+        condition_constraint_rhs = condition_constraint.rhs
+
+        # transform the condition constraint
+        condition_constraint_rhs, counter = iterate_till_fixed_point(
+            condition_constraint_rhs, counter
+        )
+
+        transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
+
+        transformed_condition_constraint, counter = transform_to_z3(
+            condition_constraint_rhs, counter, dimension_dict
+        )
+
+        negation_transformed_condition_constraint = z3.Not(
+            transformed_condition_constraint
+        )
+
+        return z3.And([transformed, transformed_condition_constraint]), z3.And(
+            [transformed, negation_transformed_condition_constraint]
+        )
+
+    def evaluate_conditional_with_constraints(
+        tracer_root, graph, node, counter=0, user_constraints=None
+    ):
+        """
+        Given an IR and a node representing a conditional, evaluate the conditional
+        and its negation
+        Args:
+            tracer_root: Tracer root for module instances
+            graph, node, counter: forwarded to transform_all_constraints_trace_time
+            user_constraints: optional extra constraints added to both solvers
+        Returns: the results of evaluating the condition and the negation with
+        the rest of the constraints
+        (each result is the value returned by z3.Solver.check())
+        """
+
+        (
+            transformed_positive,
+            transformed_negative,
+        ) = transform_all_constraints_trace_time(tracer_root, graph, node, counter)
+
+        s = z3.Solver()
+        s.add(transformed_positive)
+        if user_constraints is not None:
+            s.add(user_constraints)
+        condition = s.check()
+
+        s = z3.Solver()  # fresh solver so the positive constraints are not carried over
+        s.add(transformed_negative)
+        if user_constraints is not None:
+            s.add(user_constraints)
+        negation = s.check()
+        return condition, negation
+
+except ImportError:
+    HAS_Z3 = False
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b160ec8de70f950db66cbe51d3657fbaf6b3aaf1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
@@ -0,0 +1,58 @@
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    BinConstraintD,
+    BVar,
+    DVar,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import op_leq
+
+
+def gen_tvar(curr: int) -> tuple[TVar, int]:
+    """
+    Generate a tensor variable
+    :param curr: The current counter
+    :return: a tensor variable and the updated counter
+    """
+    curr += 1  # bump first: the variable is built from the incremented counter
+    return TVar(curr), curr
+
+
+def gen_dvar(curr: int) -> tuple[DVar, int]:
+    """
+    Generate a dimension variable
+    :param curr: the current counter
+    :return: a dimension variable and an updated counter
+    """
+    curr += 1  # bump first: the variable is built from the incremented counter
+    return DVar(curr), curr
+
+
+def gen_bvar(curr: int) -> tuple[BVar, int]:
+    """
+    Generate a boolean variable
+    :param curr: the current counter
+    :return: a boolean variable and an updated counter
+    """
+    curr += 1  # bump first: the variable is built from the incremented counter
+    return BVar(curr), curr
+
+
+def gen_tensor_dims(n: int, curr: int) -> tuple[list[DVar], int]:
+    """
+    Generate a list of tensor dimensions
+    :param n:  the number of dimensions
+    :param curr: the current counter
+    :return: a list of dimension variables and an updated counter
+    """
+    dims = []
+    for _ in range(n):
+        dvar, curr = gen_dvar(curr)  # thread the counter through so each call sees a new value
+        dims.append(dvar)
+    return dims, curr
+
+
+def gen_nat_constraints(list_of_dims: list[DVar]) -> list[BinConstraintD]:
+    """
+    Build one BinConstraintD(0, d, op_leq) per dimension d (presumably encoding 0 <= d)
+    """
+    return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b77f6396206e37e51bbb1ff68479b55bc062fd
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py
@@ -0,0 +1,59 @@
+import operator
+
+import torch
+
+
+def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
+    """
+    Annotate the type of getitem nodes, inferred from the type of sequence node.
+    If sequence node is not annotated with a type, do nothing.
+    Currently supports getitem nodes from tuple, list, and NamedTuple sequence nodes.
+
+    This is helpful since annotations on local names within a function are lost during FX transforms.
+    Adding back known type annotations for getitem nodes improves jit scriptability.
+
+    Args:
+        graph (Graph): The graph to be annotated
+    """
+    for node in graph.nodes:
+        if node.target is operator.getitem:
+            sequence_node, index_node = node.args
+            if not sequence_node.type:
+                continue
+            # container types exposing a _name attribute (matched below as "Tuple" / "List")
+            if hasattr(sequence_node.type, "_name"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type._name == "Tuple":
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]  # (T, ...) form: every element is T
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type._name == "List":
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
+            # generic alias types exposing __origin__ (matched below as tuple / list)
+            elif hasattr(sequence_node.type, "__origin__"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type.__origin__ is tuple:
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]  # (T, ...) form: every element is T
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type.__origin__ is list:
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
+            # NamedTuple type: look the field up by position, then read its annotation
+            elif hasattr(sequence_node.type, "__annotations__"):
+                if sequence_node.type == torch.Tensor:
+                    continue  # Tensor reaches this branch too; it is left unannotated
+                sequence_node_field_types = sequence_node.type.__annotations__
+                field_name = sequence_node.type._fields[index_node]
+                node.type = sequence_node_field_types[field_name]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ce645df8fa92e03e912da7d66f9b8622edeec7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py
@@ -0,0 +1,504 @@
+# mypy: allow-untyped-defs
+
+import hashlib
+from itertools import chain
+from types import ModuleType
+from typing import Any, Optional, TYPE_CHECKING
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import _parse_stack_trace
+from torch.fx.node import _format_arg, _get_qualified_name
+from torch.fx.operator_schemas import normalize_function
+from torch.fx.passes.shape_prop import TensorMetadata
+
+
+if TYPE_CHECKING:
+    import pydot
+
+    HAS_PYDOT = True
+else:
+    pydot: Optional[ModuleType]
+    try:
+        import pydot
+
+        HAS_PYDOT = True
+    except ModuleNotFoundError:
+        HAS_PYDOT = False
+        pydot = None
+
+
+__all__ = ["FxGraphDrawer"]
+
+_COLOR_MAP = {  # node.op -> fillcolor; ops not listed here use _HASH_COLOR_MAP
+    "placeholder": '"AliceBlue"',
+    "call_module": "LemonChiffon1",
+    "get_param": "Yellow2",
+    "get_attr": "LightGrey",
+    "output": "PowderBlue",
+}
+
+_HASH_COLOR_MAP = [  # palette indexed by a hash of the node's target name
+    "CadetBlue1",
+    "Coral",
+    "DarkOliveGreen1",
+    "DarkSeaGreen1",
+    "GhostWhite",
+    "Khaki1",
+    "LavenderBlush1",
+    "LightSkyBlue",
+    "MistyRose1",
+    "MistyRose2",
+    "PaleTurquoise2",
+    "PeachPuff1",
+    "Salmon",
+    "Thistle1",
+    "Thistle3",
+    "Wheat1",
+]
+
+_WEIGHT_TEMPLATE = {  # style for parameter/buffer nodes; "shape" is set in FxGraphDrawer.__init__
+    "fillcolor": "Salmon",
+    "style": '"filled,rounded"',
+    "fontcolor": "#000000",
+}
+
+if HAS_PYDOT:
+
+    @compatibility(is_backward_compatible=False)
+    class FxGraphDrawer:
+        """
+        Visualize a torch.fx.Graph with graphviz
+        Basic usage:
+            g = FxGraphDrawer(symbolic_traced, "resnet18")
+            g.get_dot_graph().write_svg("a.svg")
+        """
+
+        def __init__(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool = False,
+            ignore_parameters_and_buffers: bool = False,
+            skip_node_names_in_args: bool = True,
+            parse_stack_trace: bool = False,
+            dot_graph_shape: Optional[str] = None,
+            normalize_args: bool = False,
+        ):
+            self._name = name
+            self.dot_graph_shape = (
+                dot_graph_shape if dot_graph_shape is not None else "record"
+            )
+            self.normalize_args = normalize_args
+            _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape  # NOTE: writes to the module-level dict
+
+            self._dot_graphs = {  # the main graph is stored under `name`
+                name: self._to_dot(
+                    graph_module,
+                    name,
+                    ignore_getattr,
+                    ignore_parameters_and_buffers,
+                    skip_node_names_in_args,
+                    parse_stack_trace,
+                )
+            }
+
+            for node in graph_module.graph.nodes:
+                if node.op != "call_module":
+                    continue
+
+                leaf_node = self._get_leaf_node(graph_module, node)
+
+                if not isinstance(leaf_node, torch.fx.GraphModule):
+                    continue  # only called submodules that are GraphModules get their own graph
+
+                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
+                    leaf_node,
+                    f"{name}_{node.target}",
+                    ignore_getattr,
+                    ignore_parameters_and_buffers,
+                    skip_node_names_in_args,
+                    parse_stack_trace,
+                )
+
+        def get_dot_graph(self, submod_name=None) -> pydot.Dot:
+            """
+            Return the main dot graph, or the graph of ``submod_name`` when it is given
+            Example:
+                >>> # xdoctest: +REQUIRES(module:pydot)
+                >>> # xdoctest: +REQUIRES(module:ubelt)
+                >>> # define module
+                >>> class MyModule(torch.nn.Module):
+                >>>     def __init__(self) -> None:
+                >>>         super().__init__()
+                >>>         self.linear = torch.nn.Linear(4, 5)
+                >>>     def forward(self, x):
+                >>>         return self.linear(x).clamp(min=0.0, max=1.0)
+                >>> module = MyModule()
+                >>> # trace the module
+                >>> symbolic_traced = torch.fx.symbolic_trace(module)
+                >>> # setup output file
+                >>> import ubelt as ub
+                >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
+                >>> fpath = dpath / "linear.svg"
+                >>> # draw the graph
+                >>> g = FxGraphDrawer(symbolic_traced, "linear")
+                >>> g.get_dot_graph().write_svg(fpath)
+            """
+            if submod_name is None:
+                return self.get_main_dot_graph()
+            else:
+                return self.get_submod_dot_graph(submod_name)
+
+        def get_main_dot_graph(self) -> pydot.Dot:
+            return self._dot_graphs[self._name]  # stored under the name given to __init__
+
+        def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
+            return self._dot_graphs[f"{self._name}_{submod_name}"]  # key format used by __init__
+
+        def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
+            return self._dot_graphs  # the internal dict itself, not a copy
+
+        def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
+            template = {
+                "shape": self.dot_graph_shape,
+                "fillcolor": "#CAFFE3",
+                "style": '"filled,rounded"',
+                "fontcolor": "#000000",
+            }
+            if node.op in _COLOR_MAP:
+                template["fillcolor"] = _COLOR_MAP[node.op]
+            else:
+                # Pick a palette color from an md5 hash of the target name, so it is deterministic.
+                target_name = node._pretty_print_target(node.target)
+                target_hash = int(
+                    hashlib.md5(
+                        target_name.encode(), usedforsecurity=False
+                    ).hexdigest()[:8],
+                    16,
+                )
+                template["fillcolor"] = _HASH_COLOR_MAP[
+                    target_hash % len(_HASH_COLOR_MAP)
+                ]
+            return template
+
+        def _get_leaf_node(
+            self, module: torch.nn.Module, node: torch.fx.Node
+        ) -> torch.nn.Module:
+            py_obj = module
+            assert isinstance(node.target, str)
+            atoms = node.target.split(".")  # resolve the dotted target one attribute at a time
+            for atom in atoms:
+                if not hasattr(py_obj, atom):
+                    raise RuntimeError(
+                        str(py_obj) + " does not have attribute " + atom + "!"
+                    )
+                py_obj = getattr(py_obj, atom)
+            return py_obj
+
+        def _typename(self, target: Any) -> str:
+            if isinstance(target, torch.nn.Module):
+                ret = torch.typename(target)
+            elif isinstance(target, str):
+                ret = target  # already a string; used verbatim
+            else:
+                ret = _get_qualified_name(target)
+
+            # Escape "{" and "}" to prevent dot files like:
+            # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
+            # which triggers `Error: bad label format (...)` from dot
+            return ret.replace("{", r"\{").replace("}", r"\}")
+
+        # shorten path to avoid drawing long boxes
+        # for full path = '/home/weif/pytorch/test.py'
+        # return short path = 'pytorch/test.py'
+        def _shorten_file_name(
+            self,
+            full_file_name: str,
+            truncate_to_last_n: int = 2,
+        ):
+            splits = full_file_name.split("/")  # only "/" is treated as a separator
+            if len(splits) >= truncate_to_last_n:
+                return "/".join(splits[-truncate_to_last_n:])  # keep the last n components
+            return full_file_name
+
+        def _get_node_label(
+            self,
+            module: torch.fx.GraphModule,
+            node: torch.fx.Node,
+            skip_node_names_in_args: bool,
+            parse_stack_trace: bool,
+        ) -> str:
+            def _get_str_for_args_kwargs(arg):
+                if isinstance(arg, tuple):
+                    prefix, suffix = r"|args=(\l", r",\n)\l"
+                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
+                elif isinstance(arg, dict):
+                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
+                    arg_strs_list = [
+                        f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items()
+                    ]
+                else:  # Fall back to nothing in unexpected case.
+                    return ""
+
+                # Drop entries containing "%" (presumably node references) if requested.
+                if skip_node_names_in_args:
+                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
+                if len(arg_strs_list) == 0:
+                    return ""
+                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
+                if len(arg_strs_list) == 1:  # single entry: strip the line-break escapes
+                    arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
+                return arg_strs.replace("{", r"\{").replace("}", r"\}")
+
+            label = "{" + f"name=%{node.name}|op_code={node.op}\n"
+
+            if node.op == "call_module":
+                leaf_module = self._get_leaf_node(module, node)
+                label += r"\n" + self._typename(leaf_module) + r"\n|"
+                extra = ""
+                if hasattr(leaf_module, "__constants__"):
+                    extra = r"\n".join(
+                        [
+                            f"{c}: {getattr(leaf_module, c)}"
+                            for c in leaf_module.__constants__  # type: ignore[union-attr]
+                        ]  # type: ignore[union-attr]
+                    )
+                label += extra + r"\n"
+            else:
+                label += f"|target={self._typename(node.target)}" + r"\n"
+                if self.normalize_args:
+                    try:
+                        args, kwargs = normalize_function(  # type: ignore[misc]
+                            node.target,  # type: ignore[arg-type]
+                            node.args,  # type: ignore[arg-type]
+                            node.kwargs,
+                            normalize_to_only_use_kwargs=True,
+                        )
+                    except Exception:
+                        # Fallback to not normalizing if there's an exception.
+                        # Some functions need overloads specified to normalize.
+                        args, kwargs = node.args, node.kwargs
+                else:
+                    args, kwargs = node.args, node.kwargs
+                if len(args) > 0:
+                    label += _get_str_for_args_kwargs(args)
+                if len(kwargs) > 0:
+                    label += _get_str_for_args_kwargs(kwargs)
+                label += f"|num_users={len(node.users)}" + r"\n"
+
+            tensor_meta = node.meta.get("tensor_meta")
+            label += self._tensor_meta_to_label(tensor_meta)
+
+            # for original fx graph
+            # print buf=buf0, n_origin=6
+            buf_meta = node.meta.get("buf_meta", None)
+            if buf_meta is not None:
+                label += f"|buf={buf_meta.name}" + r"\n"
+                label += f"|n_origin={buf_meta.n_origin}" + r"\n"
+
+            # for original fx graph
+            # print file:lineno code
+            if parse_stack_trace and node.stack_trace is not None:
+                parsed_stack_trace = _parse_stack_trace(node.stack_trace)
+                fname = self._shorten_file_name(parsed_stack_trace.file)
+                label += (
+                    f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
+                    + r"\n"
+                )
+
+            return label + "}"
+
+        def _tensor_meta_to_label(self, tm) -> str:
+            if tm is None:
+                return ""
+            elif isinstance(tm, TensorMetadata):
+                return self._stringify_tensor_meta(tm)
+            elif isinstance(tm, list):  # containers: concatenate the labels of their items
+                result = ""
+                for item in tm:
+                    result += self._tensor_meta_to_label(item)
+                return result
+            elif isinstance(tm, dict):  # only the values are labelled; keys are ignored
+                result = ""
+                for v in tm.values():
+                    result += self._tensor_meta_to_label(v)
+                return result
+            elif isinstance(tm, tuple):
+                result = ""
+                for item in tm:
+                    result += self._tensor_meta_to_label(item)
+                return result
+            else:
+                raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
+
+        def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
+            result = ""
+            if not hasattr(tm, "dtype"):
+                print("tm", tm)  # debug aid only: the tm.dtype access below still raises
+            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
+            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
+            result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
+            result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
+            if tm.is_quantized:
+                assert tm.qparams is not None
+                assert "qscheme" in tm.qparams
+                qscheme = tm.qparams["qscheme"]
+                if qscheme in {  # per-tensor schemes: scale and zero_point only
+                    torch.per_tensor_affine,
+                    torch.per_tensor_symmetric,
+                }:
+                    result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
+                    result += (
+                        "|"
+                        + "q_zero_point"
+                        + "="
+                        + str(tm.qparams["zero_point"])
+                        + r"\n"
+                    )
+                elif qscheme in {  # per-channel schemes: scale, zero_point and axis
+                    torch.per_channel_affine,
+                    torch.per_channel_symmetric,
+                    torch.per_channel_affine_float_qparams,
+                }:
+                    result += (
+                        "|"
+                        + "q_per_channel_scale"
+                        + "="
+                        + str(tm.qparams["scale"])
+                        + r"\n"
+                    )
+                    result += (
+                        "|"
+                        + "q_per_channel_zero_point"
+                        + "="
+                        + str(tm.qparams["zero_point"])
+                        + r"\n"
+                    )
+                    result += (
+                        "|"
+                        + "q_per_channel_axis"
+                        + "="
+                        + str(tm.qparams["axis"])
+                        + r"\n"
+                    )
+                else:
+                    raise RuntimeError(f"Unsupported qscheme: {qscheme}")
+                result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
+            return result
+
+        def _get_tensor_label(self, t: torch.Tensor) -> str:
+            return str(t.dtype) + str(list(t.shape)) + r"\n"  # dtype directly followed by the shape list
+
+        # when parse_stack_trace=True
+        # print file:lineno code
+        def _to_dot(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool,
+            ignore_parameters_and_buffers: bool,
+            skip_node_names_in_args: bool,
+            parse_stack_trace: bool,
+        ) -> pydot.Dot:
+            """
+            Build the pydot graph for ``graph_module`` (takes the GraphModule, not the Graph).
+            ``ignore_getattr`` skips get_attr nodes and their outgoing edges; if
+            ``ignore_parameters_and_buffers`` is True, module parameters/buffers get no nodes or edges.
+            """
+
+            # "TB" means top-to-bottom rank direction in layout
+            dot_graph = pydot.Dot(name, rankdir="TB")
+
+            buf_name_to_subgraph = {}
+
+            for node in graph_module.graph.nodes:
+                if ignore_getattr and node.op == "get_attr":
+                    continue
+
+                style = self._get_node_style(node)
+                dot_node = pydot.Node(
+                    node.name,
+                    label=self._get_node_label(
+                        graph_module, node, skip_node_names_in_args, parse_stack_trace
+                    ),
+                    **style,  # type: ignore[arg-type]
+                )
+
+                current_graph = dot_graph
+
+                buf_meta = node.meta.get("buf_meta", None)
+                if buf_meta is not None and buf_meta.n_origin > 1:
+                    buf_name = buf_meta.name
+                    if buf_name not in buf_name_to_subgraph:
+                        buf_name_to_subgraph[buf_name] = pydot.Cluster(
+                            buf_name, label=buf_name
+                        )
+                    current_graph = buf_name_to_subgraph.get(buf_name)  # type: ignore[assignment]
+
+                # pyrefly: ignore [missing-attribute]
+                current_graph.add_node(dot_node)
+
+                def get_module_params_or_buffers():
+                    for pname, ptensor in chain(
+                        leaf_module.named_parameters(),
+                        # pyrefly: ignore [bad-argument-type]
+                        leaf_module.named_buffers(),
+                    ):
+                        pname1 = node.name + "." + pname
+                        # The conditional must pick only the word after "get_"; left
+                        # unparenthesized it replaced the whole label with "buffer\l".
+                        is_param = isinstance(ptensor, torch.nn.Parameter)
+                        kind = "parameter" if is_param else "buffer"
+                        label1 = pname1 + "|op_code=get_" + kind + r"\l"
+                        dot_w_node = pydot.Node(
+                            pname1,
+                            label="{" + label1 + self._get_tensor_label(ptensor) + "}",
+                            **_WEIGHT_TEMPLATE,  # type: ignore[arg-type]
+                        )
+                        dot_graph.add_node(dot_w_node)
+                        dot_graph.add_edge(pydot.Edge(pname1, node.name))
+
+                if node.op == "call_module":
+                    leaf_module = self._get_leaf_node(graph_module, node)
+
+                    if not ignore_parameters_and_buffers and not isinstance(
+                        leaf_module, torch.fx.GraphModule
+                    ):
+                        get_module_params_or_buffers()
+
+            for subgraph in buf_name_to_subgraph.values():
+                subgraph.set("color", "royalblue")
+                subgraph.set("penwidth", "2")
+                dot_graph.add_subgraph(subgraph)  # type: ignore[arg-type]
+
+            for node in graph_module.graph.nodes:
+                if ignore_getattr and node.op == "get_attr":
+                    continue
+
+                for user in node.users:
+                    dot_graph.add_edge(pydot.Edge(node.name, user.name))
+
+            return dot_graph
+
+else:
+    if not TYPE_CHECKING:
+
+        @compatibility(is_backward_compatible=False)
+        class FxGraphDrawer:  # placeholder defined when HAS_PYDOT is False; construction always raises
+            def __init__(
+                self,
+                graph_module: torch.fx.GraphModule,
+                name: str,
+                ignore_getattr: bool = False,
+                ignore_parameters_and_buffers: bool = False,
+                skip_node_names_in_args: bool = True,
+                parse_stack_trace: bool = False,
+                dot_graph_shape: Optional[str] = None,
+                normalize_args: bool = False,
+            ):
+                raise RuntimeError(
+                    "FXGraphDrawer requires the pydot package to be installed. Please install "
+                    "pydot through your favorite Python package manager."
+                )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e21668d8ff17d122bc8ba4464682514af43e7bff
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5409f614df8d343aeb371b56c6e0a7d79789fc1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a4376c253a5043caf8326eaa7610e332a791785
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00f5785a95146f78d5fdab11e62d5f04a23dc222
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e17a8040e6a9573200e10bb1fa670bb71219a26
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py
@@ -0,0 +1,97 @@
+from collections.abc import Callable
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.fx._compatibility import compatibility
+from torch.fx.graph_module import GraphModule
+
+
+__all__ = [
+    "default_matching",
+    "extract_attrs_for_lowering",
+    "lift_lowering_attrs_to_nodes",
+]
+
+
+# A matching method maps an attribute name of the current version to the attribute name of `target_version`.
+@compatibility(is_backward_compatible=False)
+def default_matching(name: str, target_version: int) -> str:
+    """Default matching method: return `name` unchanged, ignoring `target_version`."""
+    return name
+
+
+# This dict maps an nn.Module class to a (version, attribute names to fetch for lowering, matching method) tuple.
+# The first integer in the tuple is the version number of the nn.Module class when we created the parameter list.
+# If there's a version mismatch then the parameter names in the book might not match those of the nn.Module.
+module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
+    torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
+    torch.nn.modules.conv.Conv2d: (
+        1,
+        [
+            "weight",
+            "bias",
+            "kernel_size",
+            "stride",
+            "padding",
+            "dilation",
+            "groups",
+            "padding_mode",
+        ],
+        default_matching,
+    ),
+    torch.nn.modules.batchnorm.BatchNorm2d: (
+        2,
+        ["weight", "bias", "running_mean", "running_var", "eps"],
+        default_matching,
+    ),
+    torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
+    torch.nn.modules.pooling.MaxPool2d: (
+        1,
+        ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"],
+        default_matching,
+    ),
+    torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
+}
+
+
+@compatibility(is_backward_compatible=False)
+def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
+    """Return {"name": torch.typename(mod)} plus each attribute `module_fetch_book` lists for `type(mod)`.
+    Raises RuntimeError if `type(mod)` is not in the book or `mod._version` is newer than the book's version.
+    """
+    attrs_for_lowering: dict[str, Any] = {}
+    attrs_for_lowering["name"] = torch.typename(mod)
+
+    if type(mod) in module_fetch_book:  # exact-type lookup; subclasses of listed modules do not match
+        version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
+        if version < mod._version:
+            raise RuntimeError(
+                f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
+                "please upgrade the module_fetch_book, open an issue and @842974287 "
+                "or report a bug to AIACC team directly."
+            )
+        for attr in param_to_fetch:
+            attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
+    else:
+        raise RuntimeError(
+            f"{torch.typename(mod)} is not in the module_fetch_book yet, "
+            "please add it to the module_fetch_book, open an issue and @842974287 "
+            "or report a bug to AIACC team directly."
+        )
+    return attrs_for_lowering
+
+
+@compatibility(is_backward_compatible=False)
+def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
+    """Set `attrs_for_lowering` on each `call_module` node of `fx_module`, recursing into GraphModule submodules."""
+    submodules = dict(fx_module.named_modules())  # name -> module, used to resolve node.target
+
+    for node in fx_module.graph.nodes:
+        if node.op == "call_module":
+            if isinstance(submodules[node.target], GraphModule):
+                lift_lowering_attrs_to_nodes(submodules[node.target])
+            else:
+                node.attrs_for_lowering = extract_attrs_for_lowering(
+                    submodules[node.target]
+                )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py
new file mode 100644
index 0000000000000000000000000000000000000000..e475a5bc9b6df55dc640d80dc6510ed263368c4a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py
@@ -0,0 +1,658 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+import operator
+import sys
+from typing import Any, Optional, TYPE_CHECKING
+
+
+# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
+if TYPE_CHECKING:
+    import sympy
+
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+else:
+    ShapeEnv = Any
+
+import torch
+import torch.utils._pytree as pytree
+from torch import fx
+from torch._subclasses.meta_utils import is_sparse_any
+from torch.fx._compatibility import compatibility
+from torch.fx._utils import lazy_format_graph_code
+from torch.fx.experimental.proxy_tensor import py_sym_types
+from torch.fx.experimental.sym_node import SymNode
+from torch.fx.graph_module import GraphModule
+
+
+__all__ = ["insert_deferred_runtime_asserts"]
+
+log = logging.getLogger(__name__)
+graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose")
+
+
+def _get_example_value(node: fx.Node) -> Optional[Any]:
+    """
+    Get the example value stored in a node's meta, since dynamo uses the
+    "example_value" key while non-strict export uses "val"; None if neither is set.
+    """
+    if "example_value" in node.meta:
+        return node.meta["example_value"]
+    elif "val" in node.meta:
+        return node.meta["val"]
+    else:
+        return None
+
+
+def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
+    val = _get_example_value(node)  # meta["example_value"] or meta["val"], else None
+    if isinstance(val, py_sym_types):
+        return val.node.expr  # expression held by the symbolic value's node
+    return None  # example value is missing or is not one of py_sym_types
+
+
+@compatibility(is_backward_compatible=True)
+def insert_deferred_runtime_asserts(
+    gm: GraphModule,  # module whose graph is rewritten in place
+    shape_env: ShapeEnv,  # supplies deferred_runtime_asserts, var_to_range and size_like
+    name: str,  # only used to label the debug graph dump below
+    export: bool = False,  # selects the sympy analysis class and the size-constraint call emitted
+) -> None:
+    """
+    During tracing, we may have discovered that some data-dependent values
+    had runtime asserts on them; e.g., torch.empty(x.item()) induces a runtime
+    assert that x.item() >= 0.  These asserts can happen unpredictably during fake
+    tensor propagation, so we cannot conveniently insert them into the FX graph
+    when they occur.  Instead, we accumulate them in the ShapeEnv, and in this
+    pass insert them into the graph as proper tests.
+
+    This pass also deduplicates size-related computation, CSE-ing ops that produce
+    symbolic values and/or are involved in runtime asserts. Additionally, shape calls
+    (size/stride/storage_offset) are turned into compute on input sizes if possible,
+    allowing intermediate tensors to be freed earlier. For example, here dynamo will
+    DCE the cat and repeat calls:
+
+        z = torch.cat([x, x], dim=0)  # 2*s0
+        w = z.repeat(y.shape[0])  # 2*s0*s1
+        _w = w.shape[0]
+        # something with _w, but not w ...
+
+        # turns into ->
+        _w0 = 2 * s0
+        _w = _w0 * s1
+
+        # where s0, s1 are either SymInt graph inputs, or the result of added size calls
+
+    Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
+    the same expression, and redundant constrain_range calls are also deduplicated.
+    Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
+    information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol,
+    and we delete all previous calls, adding bound checks at the end of this pass.
+    """
+
+    # Import sympy locally (it is slow to import, so it is not imported at module level)
+    import sympy
+
+    from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
+    from torch.fx.experimental.symbolic_shapes import (
+        _get_placeholder_expr,
+        _has_uninterpretable_sympy_function,
+        CallMethodKey,
+        cast_symbool_to_symint_guardless,
+        ConvertIntKey,
+        DivideByKey,
+        free_symbols,
+        InnerTensorKey,
+        resolve_unbacked_bindings,
+    )
+    from torch.utils._sympy.numbers import int_oo
+    from torch.utils._sympy.reference import (
+        OptimizedPythonReferenceAnalysis,
+        PythonReferenceAnalysis,
+    )
+    from torch.utils._sympy.value_ranges import ValueRanges
+
+    # TODO: Request simplification on runtime asserts before emitting them
+    ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
+    graph = gm.graph
+    tracer = fx.proxy.GraphAppendingTracer(graph)
+    graph_code_log.debug(
+        "%s",
+        lazy_format_graph_code(
+            f"pre insert_deferred_runtime_asserts {name}", gm, colored=True
+        ),
+    )
+
+    # We are going to mutate the dict
+    expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
+    placeholders = set()
+    first_non_placeholder = None  # stays None only if every node is a placeholder
+    for node in graph.nodes:
+        if node.op != "placeholder":
+            first_non_placeholder = node
+            break
+        else:
+            placeholders.add(node)
+
+    def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool:
+        """
+        If a size/stride/storage offset call on an intermediate tensor,
+        we can try to compute the value from input shapes instead.
+        """
+        return (
+            (val := _get_sym_val(node)) is not None
+            and not isinstance(val, sympy.Number)
+            # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
+            and not _has_uninterpretable_sympy_function(val)
+            and any(
+                isinstance(arg, fx.Node)
+                and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
+                and arg.op != "placeholder"
+                for arg in node.args
+            )
+        )
+
+    # Figure out what key to use, val or example_value
+    val_key = "val"
+    for node in graph.nodes:
+        if "example_value" in node.meta:
+            val_key = "example_value"
+            break
+        elif "val" in node.meta:
+            break
+
+    def _node_metadata_hook(
+        node: torch.fx.Node,
+        stack_trace: Optional[str] = None,
+        nn_module_stack: Optional[dict[str, Any]] = None,
+        custom: Optional[dict[str, Any]] = None,
+    ) -> None:
+        fake_args = pytree.tree_map(
+            lambda arg: (
+                _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
+            ),
+            node.args,
+        )
+        try:
+            target = node.target
+            if node.op == "call_method":
+                assert isinstance(node.target, str)
+                target = getattr(fake_args[0], node.target)
+                fake_args = fake_args[1:]
+            node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
+        except NotImplementedError:
+            # This can happen when attempting to reify a symbol with an unsupported call_function node,
+            # e.g. with NestedTensors + sym_size.int via match_symbol().
+            # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input.
+            pass
+        if stack_trace is not None:
+            node.meta["stack_trace"] = stack_trace
+        if nn_module_stack is not None:
+            node.meta["nn_module_stack"] = nn_module_stack
+        if custom is not None:
+            node.meta["custom"] = custom
+
+    # Track asserts/checks we've added
+    added_asserts: set[sympy.Expr] = set()
+    constrained_unbacked_symbols: set[sympy.Symbol] = set()
+
+    Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
+
+    def _sympy_interp(expr_to_proxy, expr):
+        # sympy_interp() with hash consing
+        from sympy import Integer, Number, Symbol
+        from sympy.logic.boolalg import BooleanAtom
+
+        from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
+
+        # hash cons
+        if expr in expr_to_proxy:
+            return expr_to_proxy[expr]
+        # base cases, don't cache
+        if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
+            return sympy_interp(Analysis, expr_to_proxy, expr)
+
+        # hash cons on arguments, run expr handler
+        expr_to_proxy[expr] = _run_sympy_handler(
+            Analysis,
+            [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
+            expr,
+        )
+        return expr_to_proxy[expr]
+
+    def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
+        # This is probably unnecessary, but since torch._check() calls for single-symbol bounds
+        # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
+        # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
+        if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
+            return False
+        lhs, rhs = expr.args
+        return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
+            isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
+        )
+
+    def add_runtime_asserts(ras):
+        for ra in ras:
+            if (
+                # redundant
+                ra.expr in added_asserts
+                # if we've already added a constrain_range call for this symbol,
+                # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
+                or (
+                    len(ra.expr.free_symbols) == 1
+                    and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
+                    and _is_bound_expr_for_symbol(ra.expr)
+                )
+                # don't try to reify sympy functions we can't turn into FX nodes
+                or _has_uninterpretable_sympy_function(ra.expr)
+            ):
+                continue
+
+            log.debug("inserting runtime assert %s", ra.expr)
+            # Need to process ALL free symbols, not just unbacked ones
+            fvs = free_symbols(ra.expr)
+            missing = fvs - expr_to_proxy.keys()
+            if missing:
+                i1 = min(missing, key=str)  # defer: re-queue under one still-unbound symbol
+                # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
+                # assert shape_env.is_unbacked_symint(i1), i1
+                ras_by_symbol.setdefault(i1, []).append(ra)
+            else:
+                # Convert the sympy expression into a sequence of FX
+                # nodes
+                with _set_node_metadata_hook(gm, _node_metadata_hook):
+                    res = _sympy_interp(expr_to_proxy, ra.expr).node
+
+                    graph.call_function(
+                        torch.ops.aten._assert_scalar.default,
+                        # TODO: use ra.msg here, but it's pretty
+                        # useless right now
+                        (
+                            res,
+                            f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
+                        ),
+                    )
+                added_asserts.add(ra.expr)
+
+    nodes = list(graph.nodes)
+    for i, node in enumerate(nodes[:-1]):  # skips the last node (presumably the output node)
+        # Placeholders can match symbols, but when we destructure them
+        # with size we have to make sure we insert the nodes after all
+        # the placeholders
+        with graph.inserting_before(
+            nodes[i + 1] if node not in placeholders else first_non_placeholder
+        ):
+            # Unfortunately, this logic still must remain because manual
+            # make_fx calls may not explicitly bind all symbolic ints as
+            # arguments to the function, so we must infer it from the other
+            # arguments
+            if (
+                node in placeholders
+                and (example_value := _get_example_value(node)) is not None
+            ):
+
+                def match_symbol(symint, cb):
+                    if (
+                        isinstance(symint, torch.SymInt)
+                        and isinstance(symint.node, SymNode)
+                        and isinstance(
+                            s := _get_placeholder_expr(symint.node), sympy.Symbol
+                        )
+                        and s not in expr_to_proxy
+                    ):
+                        with _set_node_metadata_hook(gm, _node_metadata_hook):
+                            expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
+
+                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
+
+                match_symbol(example_value, lambda: node)
+
+                if isinstance(t := example_value, torch.Tensor):
+                    for i, s in enumerate(t.size()):  # NOTE: rebinds the outer loop index i
+                        match_symbol(
+                            s,
+                            lambda: graph.call_function(
+                                torch.ops.aten.sym_size.int, (node, i)
+                            ),
+                        )
+                    if not is_sparse_any(t):
+                        for i, s in enumerate(t.stride()):
+                            match_symbol(
+                                s,
+                                lambda: graph.call_function(
+                                    torch.ops.aten.sym_stride.int, (node, i)
+                                ),
+                            )
+                        match_symbol(
+                            t.storage_offset(),
+                            lambda: graph.call_function(
+                                torch.ops.aten.sym_storage_offset.default, (node,)
+                            ),
+                        )
+
+            # Handle asserts that aren't associated with any symbol.  This
+            # doesn't really have to be in the loop as it will only run once,
+            # it just needs to happen right after the placeholders.
+            # insert this after placeholders & added sym nodes, and before non-placeholders.
+            if node == first_non_placeholder:
+                add_runtime_asserts(ras_by_symbol.pop(None, []))  # type: ignore[call-overload]
+
+            # deduplicate asserts already present in graph, and remove trivial asserts
+            if node.target in (
+                torch._check,
+                torch.ops.aten._assert_scalar.default,
+            ):
+                cond = node.args[0] if node.args else node.kwargs.get("cond")
+                if (
+                    cond == True  # noqa: E712
+                    or (assert_expr := _get_sym_val(cond)) in expr_to_proxy
+                    and assert_expr in added_asserts
+                ):
+                    arg = cond
+                    gm.graph.erase_node(node)
+                    if isinstance(arg, fx.Node) and not arg.users:
+                        gm.graph.erase_node(arg)
+                else:
+                    added_asserts.add(assert_expr)  # type: ignore[arg-type]
+
+            # hash cons, replace function calls that return torch.SymInts with direct references to
+            # FX nodes built up to reify the sympy expression.
+            if (
+                node.op != "placeholder"
+                and (sym_expr := _get_sym_val(node)) is not None
+            ):
+                # this guards against deleting calls like item() that produce new untracked symbols
+                def has_new_untracked_symbols():
+                    # pyrefly: ignore [missing-attribute]
+                    for symbol in sym_expr.free_symbols:
+                        if symbol not in expr_to_proxy:
+                            return True
+                    return False
+
+                # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
+                # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
+                # (is backed), but produces an unbacked symbol. In this case keep the node alive.
+                resolved_unbacked_bindings = resolve_unbacked_bindings(
+                    shape_env, node.meta.get("unbacked_bindings", {})
+                )
+
+                def has_new_unbacked_bindings():
+                    assert resolved_unbacked_bindings is not None
+                    for key in resolved_unbacked_bindings:
+                        if key not in expr_to_proxy:
+                            return True
+                    return False
+
+                # maybe re-reify expression, replace current node
+                if (
+                    sym_expr in expr_to_proxy
+                    or (  # example value is redundant
+                        _is_intermediate_tensor_sym_call(node)
+                        # shape call on intermediate tensor, turn into computation on input shapes
+                        and not has_new_untracked_symbols()
+                    )
+                ) and not has_new_unbacked_bindings():
+                    if _is_intermediate_tensor_sym_call(
+                        node
+                    ):  # reify from input shapes
+                        with _set_node_metadata_hook(
+                            gm,
+                            functools.partial(
+                                _node_metadata_hook,
+                                stack_trace=node.meta.get("stack_trace"),
+                                nn_module_stack=node.meta.get("nn_module_stack"),
+                            ),
+                        ):
+                            expr_to_proxy[sym_expr] = _sympy_interp(
+                                expr_to_proxy,
+                                sym_expr,
+                            )  # type: ignore[arg-type]
+                        # won't try DCE-ing tensor compute here
+                    hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
+                    node.replace_all_uses_with(hash_node)
+                    gm.graph.erase_node(node)
+                    log.debug(
+                        "CSE node %s -> %s for expr %s",
+                        node,
+                        hash_node,
+                        sym_expr,
+                    )
+
+                # store node in hash cons, don't delete/replace
+
+                elif sym_expr not in expr_to_proxy and not isinstance(
+                    sym_expr,
+                    (sympy.Number, sympy.logic.boolalg.BooleanAtom),
+                ):  # don't hash cons primitives
+                    expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer)  # type: ignore[arg-type]
+
+            # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
+            # so calls before that are redundant.
+            if node.target in (
+                torch.ops.aten.sym_constrain_range.default,
+                torch.ops.aten.sym_constrain_range_for_size.default,
+            ):
+                gm.graph.erase_node(node)
+
+            defs = []
+
+            # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
+            # equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
+            # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
+            # information about the old symbol when we re-export, raising errors on data-dependent guards.
+            # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
+            if unbacked_bindings := resolve_unbacked_bindings(
+                shape_env, node.meta.get("unbacked_bindings")
+            ):
+                for s, keypath in unbacked_bindings.items():
+                    defs.append(s)
+
+                    # TODO: some CSE when generating these nodes can probably
+                    # help reduce graph size and improve compile time
+                    def go(node, keypath):
+                        if keypath == ():
+                            return node
+                        if (
+                            len(keypath) >= 2
+                            and isinstance(keypath[0], CallMethodKey)
+                            and isinstance(keypath[1], pytree.SequenceKey)
+                        ):
+                            if keypath[0].name == "size":
+                                return go(
+                                    graph.call_function(
+                                        torch.ops.aten.sym_size.int,
+                                        (node, keypath[1].idx),
+                                    ),
+                                    keypath[2:],
+                                )
+                            if keypath[0].name == "stride":
+                                return go(
+                                    graph.call_function(
+                                        torch.ops.aten.sym_stride.int,
+                                        (node, keypath[1].idx),
+                                    ),
+                                    keypath[2:],
+                                )
+
+                            return go(
+                                graph.call_method(
+                                    keypath[0].name, (node, keypath[1].idx)
+                                ),
+                                keypath[2:],
+                            )
+                        elif isinstance(keypath[0], CallMethodKey):
+                            if keypath[0].name == "storage_offset":
+                                return go(
+                                    graph.call_function(
+                                        torch.ops.aten.sym_storage_offset.default,
+                                        (node,),
+                                    ),
+                                    keypath[1:],
+                                )
+
+                            return go(
+                                graph.call_method(keypath[0].name, (node,)), keypath[1:]
+                            )
+                        elif isinstance(keypath[0], pytree.SequenceKey):
+                            return go(
+                                graph.call_function(
+                                    operator.getitem, (node, keypath[0].idx)
+                                ),
+                                keypath[1:],
+                            )
+                        elif isinstance(keypath[0], ConvertIntKey):
+                            return go(
+                                graph.call_function(
+                                    cast_symbool_to_symint_guardless, (node,)
+                                ),
+                                keypath[1:],
+                            )
+                        elif isinstance(keypath[0], DivideByKey):
+                            # TODO: need to assert divisibility
+                            return go(
+                                graph.call_function(
+                                    operator.floordiv, (node, keypath[0].divisor)
+                                ),
+                                keypath[1:],
+                            )
+                        elif isinstance(keypath[0], InnerTensorKey):
+                            return go(
+                                graph.call_function(
+                                    getattr, (node, keypath[0].inner_name)
+                                ),
+                                keypath[1:],
+                            )
+                        else:
+                            raise AssertionError(f"unrecognized keypath {keypath}")
+
+                    if s not in expr_to_proxy:
+                        with _set_node_metadata_hook(gm, _node_metadata_hook):
+                            expr_to_proxy[s] = fx.Proxy(
+                                go(node, keypath), tracer=tracer
+                            )
+                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
+
+            for i0 in defs:
+                ras = ras_by_symbol.pop(i0, [])
+                # Before we perform any asserts, first apply range
+                # refinement.  This is important, because if we are going
+                # to retrace the graph (and we typically are if we send
+                # the graph to AOTAutograd), we need to make sure we apply
+                # range refinement (a la _check_is_size) first, BEFORE we
+                # run any of the asserts.  Otherwise, we may decide to
+                # perform substitutions based on the asserts which we then
+                # can't back out, because value ranges can only be applied
+                # to asserts.
+                #
+                # A perhaps better long term plan is to avoid this order
+                # dependence by making it possible to refine ranges on
+                # arbitrary expressions, not just symbols.  But it is not
+                # so easy to make use of this information, see
+                # https://twitter.com/ezyang/status/1745801370299482492
+                # We actually made an attempt at this in
+                # https://github.com/pytorch/pytorch/pull/119043
+                # which didn't work.
+                #
+                # Other ideas for how to do this:
+                # - Have bound_sympy be the source of truth of the ranges of any expression
+                # - Cache intermediate results for every subexpression of bound_sympy
+                # - This cache should be possible to edit to refine ranges
+                #
+                # One issue with this proposal is that if
+                # we have a bound on 2x, we are not going to be able to
+                # apply it for 4x.  Similarly, we may have bounds for an
+                # equivalent expression that we are not applying because
+                # it's not a perfect match (e.g. x < y vs y > x).
+                #
+                # We already have the first issue, and it's impossible
+                # to solve in general, so any implementation on a best
+                # effort basis should do.
+                #
+                # The second issue is a preexisting one. It can be mitigated
+                # with a normalization algorithm. In general, it may also
+                # be on a best effort basis, but since our grammar is not
+                # terribly difficult, chances are we could even fully
+                # normalize SymPy expressions... who knows.
+                if i0 in constrained_unbacked_symbols:
+                    continue  # constrain symbol just once
+
+                if i0 in shape_env.size_like:
+                    if export:
+                        graph.call_function(
+                            torch.ops.aten.sym_constrain_range_for_size.default,
+                            (expr_to_proxy[i0].node,),
+                        )
+                    else:
+                        graph.call_function(
+                            torch._check_is_size, (expr_to_proxy[i0].node,)
+                        )
+
+                vr = shape_env.var_to_range[i0]
+                if vr.is_int and vr.upper == sys.maxsize - 1:
+                    # treat upper bound == sys.maxsize - 1 for int symbols as +oo
+                    # to avoid redundant runtime assert
+                    vr = ValueRanges(vr.lower, int_oo)
+                if not shape_env._default_unspecified_value_range().issubset(vr):
+                    # The runtime range is constrained, so add a runtime
+                    # assert and also explicitly refine the range
+                    # (refinement should not be necessary once runtime
+                    # asserts cause refinement, but that's NYI)
+                    def convert(s):
+                        if s in (int_oo, -int_oo):
+                            return None
+                        try:
+                            return int(s)
+                        except TypeError:
+                            return None
+
+                    if (
+                        expr_to_proxy[i0].node.target
+                        is not cast_symbool_to_symint_guardless
+                    ):
+                        # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
+                        # raises AOTAutograd errors on cast_symbool_to_symint_guardless
+
+                        with _set_node_metadata_hook(
+                            gm,
+                            functools.partial(
+                                _node_metadata_hook,
+                                stack_trace=node.meta.get("stack_trace"),
+                                nn_module_stack=node.meta.get("nn_module_stack"),
+                                # nodes added in `apply_runtime_assertion_pass` will have the same annotation
+                                # as the input node to the assertion
+                                custom=node.meta.get("custom"),
+                            ),
+                        ):
+                            if (min_val := convert(vr.lower)) is not None:
+                                ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
+                                graph.call_function(
+                                    torch.ops.aten._assert_scalar.default,
+                                    (
+                                        ge,
+                                        f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
+                                    ),
+                                )
+                                added_asserts.add(i0 >= min_val)
+                            if (max_val := convert(vr.upper)) is not None:
+                                le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
+                                graph.call_function(
+                                    torch.ops.aten._assert_scalar.default,
+                                    (
+                                        le,
+                                        f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
+                                    ),
+                                )
+                                added_asserts.add(i0 <= max_val)
+
+                constrained_unbacked_symbols.add(i0)
+                add_runtime_asserts(ras)
+
+    # delete unused reified symbols
+    for expr, proxy in expr_to_proxy.items():
+        if (
+            isinstance(expr, sympy.Symbol)
+            and proxy.node.op != "placeholder"  # keep placeholders intact
+            and not proxy.node.users
+        ):
+            log.debug("deleting unused reified symbol for %s", expr)
+            gm.graph.erase_node(proxy.node)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ea218356138de640c7fb7a74fb2efbcb4b21e5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py
@@ -0,0 +1,230 @@
+# mypy: ignore-errors
+
+import traceback
+from typing import Any, NamedTuple, Optional
+
+import torch
+import torch.fx
+from torch._dispatch.python import enable_python_dispatcher
+from torch._guards import detect_fake_mode
+from torch._prims_common import is_contiguous_for_memory_format_or_false
+from torch._subclasses.meta_utils import is_sparse_any
+from torch.fx._compatibility import compatibility
+from torch.fx.node import map_aggregate, Node
+
+
+__all__ = ["TensorMetadata", "ShapeProp"]
+
+
+@compatibility(is_backward_compatible=True)
+class TensorMetadata(NamedTuple):
+    # TensorMetadata is an immutable record of pertinent information about one
+    # tensor; it is built by _extract_tensor_metadata() and stored by ShapeProp.
+
+    # General Tensor metadata
+    shape: torch.Size  # result.shape
+    dtype: torch.dtype  # result.dtype
+    requires_grad: bool  # result.requires_grad
+    stride: tuple[int, ...]  # result.stride(), or () for sparse tensors
+    memory_format: Optional[torch.memory_format]  # first format known to hold, else None
+
+    # Quantization metadata
+    is_quantized: bool  # result.is_quantized
+    qparams: dict[str, Any]  # "qscheme" plus scale/zero_point(/axis); {} when not quantized
+
+
+# When include_contiguity is True, memory_format is set only when it is always true for the tensor.
+# Some tensors can represent both contiguous and non-contiguous tensors, e.g. (u0, u1) with (u2, u3)
+# (presumably unbacked sizes and strides - TODO confirm). In such a situation memory_format stays
+# None. We could also make it a tri-state, i.e. (def_contiguous, def_not_contiguous and unknown).
+def _extract_tensor_metadata(
+    result: torch.Tensor, include_contiguity=True
+) -> TensorMetadata:
+    """
+    Extract a TensorMetadata NamedTuple describing `result`.
+    """
+    shape = result.shape
+    dtype = result.dtype
+    requires_grad = result.requires_grad
+    stride = result.stride() if not is_sparse_any(result) else ()  # sparse: empty stride tuple
+
+    memory_format = None  # stays None when disabled, sparse, or no candidate format matches
+
+    if include_contiguity and not is_sparse_any(result):  # sparse tensors are never probed
+        memory_formats = (
+            torch.contiguous_format,
+            torch.channels_last,
+            torch.channels_last_3d,
+        )
+        for query_format in memory_formats:  # candidates are tried in order; first match wins
+            if is_contiguous_for_memory_format_or_false(
+                result, memory_format=query_format
+            ):
+                memory_format = query_format
+                break
+
+    is_quantized = result.is_quantized
+    qparams: dict[str, Any] = {}  # left empty for non-quantized tensors
+    if is_quantized:
+        qscheme = result.qscheme()
+        qparams["qscheme"] = qscheme  # always recorded; the other keys depend on the scheme
+        if qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric):
+            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
+            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
+        elif qscheme in (
+            torch.per_channel_affine,
+            torch.per_channel_affine_float_qparams,
+            torch.per_channel_symmetric,
+        ):
+            # In this branch, scale and zero_point are expected to be tensors;
+            # we convert them with .tolist() and store plain Python lists in
+            # TensorMetadata for easier serialization downstream
+            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
+            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
+            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]
+
+    return TensorMetadata(
+        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams
+    )
+
+
+@compatibility(is_backward_compatible=True)
+class ShapeProp(torch.fx.Interpreter):
+    """
+    Execute an FX graph Node-by-Node and
+    record the shape and type of the result
+    into the corresponding node.
+
+    Example:
+         In this example, we record the shape
+         and data type of a module given
+         an example input ``torch.randn(50, D_in)``.
+         We print the name, shape and dtype of each node.
+
+        class TwoLayerNet(torch.nn.Module):
+            def __init__(self, D_in, H, D_out):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(D_in, H)
+                self.linear2 = torch.nn.Linear(H, D_out)
+            def forward(self, x):
+                h_relu = self.linear1(x).clamp(min=0)
+                y_pred = self.linear2(h_relu)
+                return y_pred
+        N, D_in, H, D_out = 64, 1000, 100, 10
+        x = torch.randn(N, D_in)
+        y = torch.randn(N, D_out)
+        model = TwoLayerNet(D_in, H, D_out)
+        gm = torch.fx.symbolic_trace(model)
+        sample_input = torch.randn(50, D_in)
+        ShapeProp(gm).propagate(sample_input)
+
+        for node in gm.graph.nodes:
+            print(node.name, node.meta['tensor_meta'].dtype,
+                node.meta['tensor_meta'].shape)
+
+        The output of this code is:
+
+        x torch.float32 torch.Size([50, 1000])
+        linear1 torch.float32 torch.Size([50, 100])
+        clamp_1 torch.float32 torch.Size([50, 100])
+        linear2 torch.float32 torch.Size([50, 10])
+        output torch.float32 torch.Size([50, 10])
+
+    Args:
+         module (GraphModule): The module to be executed
+         fake_mode (FakeTensorMode): A fake mode for copying the gm
+
+    """
+
+    def __init__(self, gm, fake_mode=None):  # gm: module to interpret ("module" in the class docstring)
+        super().__init__(gm)
+        if fake_mode is None:  # no mode supplied: fall back to whatever detect_fake_mode() finds
+            fake_mode = detect_fake_mode()
+        if fake_mode is not None:
+            from torch._dynamo.utils import deepcopy_to_fake_tensor
+
+            # Note:
+            # We need fake execution because the inputs are fake; however, we cannot fakify the
+            # module itself, because we need to write to the tensor_meta of the real module's
+            # nodes. So we run a fake copy to produce a result (see run_node below), extract the
+            # tensor meta from it, and then keep going.
+            # If we were to fakify in place, we would write to the wrong node, and then downstream
+            # fusion would be missing the tensor_meta.
+            #
+            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
+            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
+            self.fake_mode = fake_mode
+        else:
+            self.fake_module = None
+            self.fake_mode = None
+
+        self.real_module = self.module  # kept so run_node can restore self.module after each node
+
+    def run_node(self, n: Node) -> Any:  # run node n, then record tensor_meta/unbacked_bindings/type in n.meta
+        from torch.fx.experimental.symbolic_shapes import (
+            compute_unbacked_bindings,
+            rebind_unbacked,
+        )
+
+        try:
+            if self.fake_module is not None:
+                # Hacky swap. Alternatively, we could do this with overriding
+                # call_module and get_attr.
+                self.module = self.fake_module
+            try:
+                if self.fake_mode is not None:
+                    with self.fake_mode, enable_python_dispatcher():
+                        result = super().run_node(n)
+                        rebind_unbacked(self.fake_mode.shape_env, n, result)
+                else:
+                    result = super().run_node(n)
+            finally:  # always swap the real module back, even when the node raised
+                self.module = self.real_module
+        except Exception as e:  # print the traceback, then re-raise with the failing node attached
+            traceback.print_exc()
+            raise RuntimeError(
+                f"ShapeProp error for: node={n.format_node()} with meta={n.meta}"
+            ) from e
+
+        found_tensor = False  # set by extract_tensor_meta once any tensor is seen in result
+
+        def extract_tensor_meta(obj):  # tensor -> TensorMetadata; anything else is returned as-is
+            if isinstance(obj, torch.Tensor):
+                nonlocal found_tensor
+                found_tensor = True
+                return _extract_tensor_metadata(obj)
+            else:
+                return obj
+
+        meta = map_aggregate(result, extract_tensor_meta)  # applied across the (possibly nested) result
+        if found_tensor:  # results holding no tensor leave n.meta["tensor_meta"] untouched
+            n.meta["tensor_meta"] = meta
+
+        if self.fake_mode:  # unbacked bindings are only computed when a fake mode is in use
+            if (shape_env := self.fake_mode.shape_env) and (
+                symbol_to_path := compute_unbacked_bindings(shape_env, result)
+            ):
+                n.meta["unbacked_bindings"] = symbol_to_path
+
+        n.meta["type"] = type(result)  # recorded for every node, tensor or not
+        return result
+
+    def propagate(self, *args):
+        """
+        Run `module` via interpretation and return the result and
+        record the shape and type of each node.
+
+        Args:
+            *args (Tensor): the sample input.
+
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        if self.fake_mode is not None:  # tensor args go through fake_mode.from_tensor; others pass through
+            fake_args = [
+                self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+                for t in args
+            ]
+        else:
+            fake_args = args
+        return super().run(*fake_args)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/split_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/split_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b244750f33dc5f5a7b233afc70f4b1e1f26cd8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/split_module.py
@@ -0,0 +1,656 @@
+# mypy: allow-untyped-defs
+import inspect
+import logging
+from collections import OrderedDict
+from collections.abc import Callable
+from typing import Any, Optional
+
+import torch
+from torch.fx._compatibility import compatibility
+from torch.fx._utils import lazy_format_graph_code
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node
+
+
+__all__ = ["Partition", "split_module"]
+log = _LOGGER = logging.getLogger(__name__)
+
+
+@compatibility(is_backward_compatible=True)
+class Partition:  # bookkeeping for one partition while split_module builds its submodule graph
+    def __init__(self, name: str):
+        self.name: str = name  # partition key: str(split_callback(node)), optionally affix-prefixed
+        self.submod_name = f"submod_{name}"  # prefix used for this partition's submodule qualnames
+        self.node_names: list[str] = []  # names of the original nodes assigned to this partition
+        self.inputs: dict[str, None] = {}  # ordered set (keys only): names of outside nodes used here
+        self.outputs: dict[str, None] = {}  # ordered set: names of this partition's nodes used elsewhere
+        self.dependencies: dict[str, None] = {}  # ordered set: partitions whose values this one uses
+        self.dependents: dict[str, None] = {}  # ordered set: partitions that use this one's values
+        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()  # graph under construction
+        self.environment: dict[Node, Node] = {}  # original node -> its counterpart in self.graph
+        self.targets: dict[str, Any] = {}  # target name -> attribute/module object it refers to
+
+    def __repr__(self) -> str:  # multi-line summary of the name, nodes, inputs/outputs and links
+        return (
+            f"name: {self.name},\n"
+            f" nodes: {self.node_names},\n"
+            f" inputs: {self.inputs},\n"
+            f" outputs: {self.outputs},\n"
+            f" partitions depended on: {self.dependencies},\n"
+            f" partition dependents: {self.dependents}"
+        )
+
+
+def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any:  # resolve dotted path on mod
+    attr_val = mod  # object reached so far while walking the dotted path
+    for atom in qualname.split("."):  # type: ignore[union-attr]
+        if not hasattr(attr_val, atom):  # fail on the first missing path component
+            raise AttributeError(f"Node target {qualname} not found!")  # reports the full qualname
+        attr_val = getattr(attr_val, atom)
+    return attr_val  # whatever sits at the end of the path (parameter, buffer, submodule, ...)
+
+
+# Creates subgraphs out of main graph
+@compatibility(is_backward_compatible=True)
+def split_module(
+    m: GraphModule,
+    root_m: torch.nn.Module,
+    split_callback: Callable[[Node], int],
+    qualname_map: Optional[dict[str, str]] = None,
+    keep_original_order: Optional[bool] = False,
+    keep_original_node_name: Optional[bool] = False,
+    keep_original_input_name: bool = True,
+    *,
+    partition_affix: Optional[str] = None,
+):
+    """
+    Creates subgraphs out of main graph
+
+    Args:
+        m (GraphModule): Graph module to split
+        root_m (torch.nn.Module): root nn module. Not currently used. Included
+            because the root nn module is usually transformed via
+            torch.fx._symbolic_trace.symbolic_trace (see example below)
+        split_callback (Callable[[Node], int]): Callable function
+            that maps a given Node instance to a numeric partition identifier.
+            split_module will use this function as the policy for which operations
+            appear in which partitions in the output Module.
+        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
+            mapping from new target names in the module after split to old target
+            names in the original module.
+        keep_original_order: Optional[bool]: keep the original order of the GraphModule
+            or use the Topological order of the new constructed GraphModule
+        keep_original_node_name: Optional[bool]: If the partitioned graphs should
+            have the same node names as the original graph.
+        keep_original_input_name: bool: If the partitioned graphs should
+            have the same input names as the original graph.
+        partition_affix: Optional[str]: If specified, the submodules' names will contain
+            the affix, e.g. "submod__".
+
+    Returns:
+        GraphModule: the module after split.
+
+    Example:
+
+        This is a sample setup:
+
+            import torch
+            from torch.fx._symbolic_trace import symbolic_trace
+            from torch.fx.graph_module import GraphModule
+            from torch.fx.node import Node
+            from torch.fx.passes.split_module import split_module
+
+            class MyModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.param = torch.nn.Parameter(torch.rand(3, 4))
+                    self.linear = torch.nn.Linear(4, 5)
+
+                def forward(self, x, y):
+                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
+                    w = self.linear(y).clamp(min=0.0, max=1.0)
+                    return z + w
+
+            # symbolically trace model
+            my_module = MyModule()
+            my_module_traced = symbolic_trace(my_module)
+
+            # random mod partitioning
+            partition_counter = 0
+            NPARTITIONS = 3
+
+            def mod_partition(node: Node):
+                global partition_counter
+                partition = partition_counter % NPARTITIONS
+                partition_counter = (partition_counter + 1) % NPARTITIONS
+                return partition
+
+            # split module in module with submodules
+            module_with_submodules = split_module(
+                my_module_traced, my_module, mod_partition
+            )
+
+        Output looks like this. Original graph is broken into partitions
+
+            > print(module_with_submodules)
+            GraphModule(
+                (submod_0): GraphModule(
+                    (linear): Linear(in_features=4, out_features=5, bias=True)
+                )
+                (submod_1): GraphModule(
+                    (linear): Linear(in_features=4, out_features=5, bias=True)
+                )
+                (submod_2): GraphModule()
+            )
+
+            def forward(self, x, y):
+                param = self.param
+                submod_0 = self.submod_0(x, param, y);  x = param = y = None
+                getitem = submod_0[0]
+                getitem_1 = submod_0[1];  submod_0 = None
+                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
+                getitem_2 = submod_1[0]
+                getitem_3 = submod_1[1];  submod_1 = None
+                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
+                return submod_2
+
+        Output of split module is the same as output of input traced module.
+        This is an example within a test setting:
+
+            > orig_out = my_module_traced(x, y)
+            > submodules_out = module_with_submodules(x, y)
+            > self.assertEqual(orig_out, submodules_out)
+            True
+    """
+
+    log.debug(
+        "%s",
+        lazy_format_graph_code("pre split_module", m, colored=True),
+    )
+
+    def construct_graph(
+        node: Node,
+        base_mod_env: dict[str, Node],
+        base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
+    ):
+        if node.op == "placeholder":
+            default_value = (
+                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+            )
+            if keep_original_node_name:
+                args = (
+                    () if default_value is inspect.Signature.empty else (default_value,)
+                )
+                base_mod_env[node.name] = base_mod_graph.create_node(
+                    "placeholder",
+                    node.name,
+                    args=args,  # type: ignore[arg-type]
+                    type_expr=node.type,
+                )
+            else:
+                base_mod_env[node.name] = base_mod_graph.placeholder(
+                    node.target,  # type: ignore[arg-type]
+                    type_expr=node.type,
+                    default_value=default_value,
+                )
+            base_mod_env[node.name].meta = node.meta.copy()
+        elif node.op == "get_attr":
+            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)  # type: ignore[arg-type]
+            base_mod_env[node.name].meta = node.meta.copy()
+            assert isinstance(node.target, str)
+            attr_val = _get_attr_from_qualname(m, node.target)
+            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
+        return base_mod_env, base_mod_attrs
+
+    import sympy
+
+    partitions: dict[str, Partition] = {}
+    orig_nodes: dict[str, Node] = {}
+    symbol_to_node: dict[sympy.Symbol, Node] = {}
+
+    def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
+        from torch.fx.experimental.symbolic_shapes import free_symbols
+
+        defined = getattr(def_node, "_fx_partition", None)
+        used = getattr(use_node, "_fx_partition", None)
+
+        log.debug(
+            "record_cross_partition_use %s (%s) %s (%s)",
+            def_node.name,
+            defined,
+            use_node.name if use_node is not None else "-",
+            used,
+        )
+
+        if defined != used:
+            if defined is not None:
+                def_partition = partitions[defined]
+                def_partition.outputs.setdefault(def_node.name)
+                if used is not None:
+                    def_partition.dependents.setdefault(used)
+
+            if used is not None:
+                use_partition = partitions[used]
+                use_partition.inputs.setdefault(def_node.name)
+                # We have made def_node an input to the use_partition.  If
+                # this input has symbolic symbols in its size, those also must
+                # be made as inputs to the partition
+                if (def_val := def_node.meta.get("example_value")) is not None:
+                    for s in sorted(free_symbols(def_val), key=str):
+                        s_node = symbol_to_node[s]
+                        use_partition.inputs.setdefault(s_node.name)
+                        if symbol_to_node[s].op != "placeholder":
+                            # If the node that defines the symbol is not a
+                            # placeholder, we must make it an output of the
+                            # partition.  Note that this may be in a different
+                            # partition than defined!  Although, this doesn't
+                            # really make a difference for correctness, since
+                            # defined is guaranteed to have the symbol in
+                            # scope and can return it; you just get less
+                            # optimal codegen in this case.
+                            s_defined = getattr(s_node, "_fx_partition", None)
+                            if s_defined is not None:
+                                s_def_partition = partitions[s_defined]
+                                s_def_partition.outputs.setdefault(s_node.name)
+                                s_def_partition.dependents.setdefault(used)
+                                use_partition.dependencies.setdefault(s_defined)
+                if defined is not None:
+                    use_partition.dependencies.setdefault(defined)
+
+    def instantiate_node_partition_mapping(node):
+        partition_idx = split_callback(node)
+        partition_name = str(partition_idx)
+        if partition_affix is not None:
+            # For example, if user specifies partition_affix = "pp", then the
+            # partition name will be "pp_0", "pp_1", etc
+            partition_name = "_".join([partition_affix, partition_name])
+
+        log.debug(
+            "instantiate_node_partition_mapping %s (%s)", node.name, partition_name
+        )
+
+        # add node to partitions
+        partition = partitions.get(partition_name)
+        if partition is None:
+            partitions[partition_name] = partition = Partition(partition_name)
+
+        partition.node_names.append(node.name)
+        node._fx_partition = partition_name
+
+    # Global State Nodes are nodes which by their global state effects,
+    # "taint" all downstream nodes while they are active.
+    GLOBAL_STATE_NODES = [
+        torch.amp._enter_autocast,
+        torch.amp._exit_autocast,
+        torch._C._set_grad_enabled,
+    ]
+
+    # For grad regions:
+    # ------------------------
+    # 1. first region: we do nothing
+    # 2. subsequent regions: we insert the set_grad at the beginning
+    grad_regions: OrderedDict[Node, set[int]] = OrderedDict()
+
+    # For autocast regions:
+    # ------------------------
+    # 1. first region: we will only insert the _exit at the end
+    # 2. intermediate regions: we will insert both the
+    #    _enter at the beginning and _exit at the end
+    # 3. last region: we will only insert _enter at the beginning
+    # We will do so in the order in which the autocasts were instantiated.
+    autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
+    autocast_exits: dict[Node, Optional[Node]] = {}
+
+    active_grad = None
+    active_autocasts = set()
+
+    for node in m.graph.nodes:
+        # This will prefer placeholder bindings, because those come first.
+        # This is a little dangerous though: it is possible that an unbacked
+        # symbol is used without any binding site for it, in which case we
+        # will get a KeyError not able to find it.  I'd like to fix this by
+        # having passes.runtime_assert establish some invariants that I can
+        # rely on later, but this needs some extra work.  Quick fix first.
+        # See https://github.com/pytorch/pytorch/issues/130534
+        if (
+            (val := node.meta.get("example_value")) is not None
+            and isinstance(val, (torch.SymInt, torch.SymFloat))
+            and isinstance(s0 := val.node.expr, sympy.Symbol)
+            and s0 not in symbol_to_node
+        ):
+            symbol_to_node[val.node.expr] = node
+
+        if node.op in ["placeholder", "get_attr", "output"]:
+            continue
+
+        instantiate_node_partition_mapping(node)
+
+        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
+            if node.target is torch._C._set_grad_enabled:
+                assert len(node.args) == 1
+                assert isinstance(node.args[0], bool)
+                active_grad = node
+                grad_regions[active_grad] = set({split_callback(node)})
+            elif node.target is torch.amp._enter_autocast:
+                # Should all be python constants
+                assert all(not isinstance(arg, Node) for arg in node.args)
+                active_autocasts.add(node)
+                autocast_regions[node] = set({split_callback(node)})
+                autocast_exits[node] = None
+            elif node.target is torch.amp._exit_autocast:
+                assert len(node.args) == 1
+                autocast_regions[node.args[0]].add(split_callback(node))
+                active_autocasts.remove(node.args[0])
+                autocast_exits[node.args[0]] = node
+
+        if active_grad is not None:
+            grad_regions[active_grad].add(split_callback(node))
+
+        for a in active_autocasts:
+            autocast_regions[a].add(split_callback(node))
+
+    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
+
+    # pyrefly: ignore [bad-assignment]
+    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
+    # pyrefly: ignore [bad-assignment]
+    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
+
+    if _LOGGER.isEnabledFor(logging.DEBUG):
+        _LOGGER.debug("autocast_regions: %s", autocast_regions)
+        _LOGGER.debug("grad_regions: %s", grad_regions)
+
+    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
+
+    # split nodes into partitions
+    highest_partition = -1
+    for node in m.graph.nodes:
+        orig_nodes[node.name] = node
+
+        # TODO currently placeholders/parameters aren't put into random partitions,
+        # rather they're added to the graphs where they are used down below
+        if node.op in ["placeholder", "get_attr"]:
+            continue
+        if node.op == "output":
+            torch.fx.graph.map_arg(
+                node.args[0], lambda n: record_cross_partition_use(n, None)
+            )
+            continue
+
+        if assert_monotonically_increasing:
+            pid = split_callback(node)
+            assert highest_partition <= pid, (
+                "autocast or set_grad_enabled require monotonically increasing partitions:"
+                f"highest: {highest_partition}, this node's: {pid}"
+            )
+            highest_partition = pid
+
+        # do not capture cross-partition dependencies for global state nodes as they will be
+        # self-contained - their setup and unwind will be isolated to each partition submodule.
+        if node.target not in GLOBAL_STATE_NODES:
+            torch.fx.graph.map_arg(
+                node.args, lambda def_node: record_cross_partition_use(def_node, node)
+            )
+            torch.fx.graph.map_arg(
+                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
+            )  # noqa: B950
+
+    original_partition_order = list(partitions.keys())
+    # find partitions with no dependencies
+    root_partitions: list[str] = []
+    for partition_name, partition in partitions.items():
+        if not len(partition.dependencies):
+            root_partitions.append(partition_name)
+
+    # check partitions for circular dependencies and create topological partition ordering
+    sorted_partitions: list[str] = []
+    while root_partitions:
+        root_partition = root_partitions.pop()
+        sorted_partitions.append(root_partition)
+        for dependent in partitions[root_partition].dependents:
+            partitions[dependent].dependencies.pop(root_partition)  # noqa: B909
+            if not partitions[dependent].dependencies:
+                root_partitions.append(dependent)
+    if len(sorted_partitions) != len(partitions):
+        raise RuntimeError("cycle exists between partitions!")
+
+    # Enter prelude
+    for regions_mapping in [autocast_regions, grad_regions]:
+        for node, regions in regions_mapping.items():
+            assert len(regions) > 0
+            # pyrefly: ignore [index-error]
+            partitions[str(regions[0])].environment[node] = node
+            # pyrefly: ignore [index-error]
+            for r in regions[1:]:
+                partition = partitions[str(r)]
+                new_node = partition.graph.create_node(
+                    op=node.op,
+                    target=node.target,
+                    args=tuple(arg for arg in node.args),
+                    kwargs={},
+                    type_expr=node.type,
+                )
+                new_node.meta = (
+                    node.meta.copy()
+                )  # is it really a good idea to copy this?
+                partition.environment[node] = new_node
+
+    # add placeholders to partition inputs
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+        new_inputs: dict[str, None] = {}
+
+        counter = 0
+
+        for inp in partition.inputs:
+            orig_node = orig_nodes[inp]
+            # We don't pass in get_attr nodes as inputs to the partition, but
+            # instead set them as targets and use getattr within the module
+
+            def add_placeholder():
+                if keep_original_input_name:
+                    name = inp
+                else:
+                    nonlocal counter
+                    name = f"arg_{counter}"
+                    counter += 1
+                placeholder = partition.graph.placeholder(
+                    name,
+                    type_expr=orig_nodes[inp].type,
+                )
+                new_inputs[inp] = None
+                return placeholder
+
+            if orig_node.op == "get_attr":
+                assert isinstance(orig_node.target, str)
+
+                orig_attr = _get_attr_from_qualname(m, orig_node.target)
+                if isinstance(orig_attr, torch.nn.Module):
+                    placeholder = partition.graph.get_attr(orig_node.target)
+                    partition.targets[orig_node.target] = orig_attr
+                else:
+                    placeholder = add_placeholder()
+            else:
+                placeholder = add_placeholder()
+            placeholder.meta = orig_nodes[inp].meta.copy()
+            partition.environment[orig_nodes[inp]] = placeholder
+        partition.inputs = new_inputs
+
+    # Transform nodes and collect targets for partition's submodule
+    for node in m.graph.nodes:
+        if hasattr(node, "_fx_partition"):
+            partition = partitions[node._fx_partition]
+
+            # swap out old graph nodes in kw/args with references to new nodes in this submodule
+            environment = partition.environment
+            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
+            gathered_kwargs = torch.fx.graph.map_arg(
+                node.kwargs, lambda n: environment[n]
+            )
+
+            if node.op not in ["call_module", "get_attr"]:
+                target = node.target
+            else:
+                target_attr = _get_attr_from_qualname(m, node.target)
+                target = node.target.replace(".", "_")
+                partition.targets[target] = target_attr
+                # Fill in the passed-in mapping from new qualname to old qualname
+                if qualname_map is not None:
+                    # When creating the split module later, the submodules will have
+                    # path prefix matching the corresponding partition's submod_name
+                    qualname = f"{partition.submod_name}.{target}"
+                    qualname_map[qualname] = node.target
+
+            assert isinstance(gathered_args, tuple)
+            assert isinstance(gathered_kwargs, dict)
+            name = node.name if keep_original_node_name else None
+            new_node = partition.graph.create_node(
+                op=node.op,
+                target=target,
+                args=gathered_args,
+                kwargs=gathered_kwargs,
+                type_expr=node.type,
+                name=name,
+            )
+            new_node.meta = node.meta.copy()
+            partition.environment[node] = new_node
+
+    # Exit epilogue
+    for regions_mapping in [autocast_regions]:
+        for node in reversed(regions_mapping):
+            regions = regions_mapping[node]
+            assert len(regions) > 0
+            # pyrefly: ignore [index-error]
+            for r in regions[:-1]:
+                partition = partitions[str(r)]
+                exit_node = autocast_exits[node]
+                assert exit_node is not None, "Missing exit node"
+                new_node = partition.graph.create_node(
+                    op=exit_node.op,
+                    target=exit_node.target,
+                    args=(partition.environment[node],),
+                    kwargs={},
+                    type_expr=exit_node.type,
+                )
+                new_node.meta = (
+                    exit_node.meta.copy()
+                )  # is it really a good idea to copy this?
+
+    # original module environment dict mapping node names to nodes
+    orig_mod_env: dict[str, Node] = {}
+    # Set up values to construct base module
+    base_mod_env: dict[str, Node] = {}
+    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+    base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
+    if not keep_original_order:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
+    else:
+        # Go through the graph to construct the mapping dict
+        for node in m.graph.nodes:
+            orig_mod_env[node.name] = node
+
+    # Do some things iterating over the partitions in topological order again:
+    # 1) Finish off submodule Graphs by setting corresponding outputs
+    # 2) Construct GraphModules for each submodule
+    # 3) Construct the base graph by emitting calls to those submodules in
+    #    topological order or original order specified by keep_original_order
+
+    construct_order_partitions = (
+        sorted_partitions if not keep_original_order else original_partition_order
+    )
+
+    already_constructed_attr_nodes = set()
+
+    # We actually need to insert the placeholder nodes in the original order
+    # otherwise graph signature will be wrong.
+    original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
+
+    for partition_name in construct_order_partitions:
+        partition = partitions[partition_name]
+
+        # Set correct output values
+        output_vals = tuple(
+            partition.environment[orig_nodes[name]] for name in partition.outputs
+        )
+
+        # skip output node generation if there are no output values
+        num_output_vals = len(output_vals)
+        if num_output_vals == 1:
+            partition.graph.output(output_vals[0])
+        elif num_output_vals > 1:
+            partition.graph.output(output_vals)
+        else:
+            # Invariant - Graph should always have an output node.
+            partition.graph.output(())
+
+        if keep_original_order:
+            # first get the attr nodes required by this partition
+            orig_mod_attr_nodes: list[Node] = [
+                orig_mod_env[key]
+                for key in partition.inputs
+                if key not in original_order
+            ]
+
+            for node in original_order:
+                if node in already_constructed_attr_nodes:
+                    continue  # already added this attr to the base graph
+                base_mod_env, _based_mod_attrs = construct_graph(
+                    node, base_mod_env, base_mod_attrs
+                )
+                already_constructed_attr_nodes.add(node)
+
+            # Construct GraphModule for this partition
+            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
+                if node in already_constructed_attr_nodes:
+                    continue
+                base_mod_env, base_mod_attrs = construct_graph(
+                    node, base_mod_env, base_mod_attrs
+                )
+                already_constructed_attr_nodes.add(node)
+
+        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
+            partition.targets, partition.graph
+        )  # noqa: B950
+
+        # Emit call in base graph to this submodule
+        output_val = base_mod_graph.call_module(
+            partition.submod_name,
+            tuple(base_mod_env[name] for name in partition.inputs),
+        )
+
+        num_outputs = len(partition.outputs)
+        if num_outputs > 1:
+            # Unpack multiple return values from submodule
+            output_val_proxy = torch.fx.proxy.Proxy(output_val)
+            for i, output_name in enumerate(partition.outputs):
+                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
+        elif num_outputs == 1:
+            base_mod_env[next(iter(partition.outputs))] = output_val
+
+    # When keep_original_order=True and if the graph doesn't have any
+    # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
+    # are never populated.
+    # For this case, we call `construct_graph` here which takes care of updating them.
+    if keep_original_order and not base_mod_env:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
+    # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
+    for node in m.graph.nodes:
+        if node.op == "output":
+            base_mod_graph.output(
+                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
+            )  # noqa: B950
+
+    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    log.debug(
+        "%s",
+        lazy_format_graph_code("post split_module", ret, colored=True),
+    )
+    return ret
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/split_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/split_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..88da7ac7c4f55fb5cf1c22546d09ceb3b406d6fb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/split_utils.py
@@ -0,0 +1,312 @@
+# mypy: allow-untyped-defs
+import copy
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import map_arg
+from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
+
+from .tools_common import NodeList
+
+
+__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
+
+
+@compatibility(is_backward_compatible=False)
+def getattr_recursive(obj, name):
+    for layer in name.split("."):
+        if isinstance(obj, torch.nn.ModuleList):
+            if hasattr(obj, "_modules") and layer in obj._modules:
+                obj = obj._modules[layer]
+            else:
+                return None
+        elif hasattr(obj, layer):
+            obj = getattr(obj, layer)
+        else:
+            return None
+    return obj
+
+
+@compatibility(is_backward_compatible=False)
+def setattr_recursive(obj, attr, value):
+    if "." not in attr:
+        setattr(obj, attr, value)
+    else:
+        layer = attr.split(".")
+        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class Component:
+    """
+    A component serves as a container for a subgraph we want to create afterwards.
+
+    ``split_by_tags`` creates one component per tag and fills it in while
+    walking the original graph.
+    """
+
+    # The FX graph being built for this component.
+    graph: torch.fx.Graph
+    # Position of this component among all components; `split_by_tags` asserts
+    # that a node's component order is >= the order of its upstream components.
+    order: int
+    # Component name; `split_by_tags` uses it as the submodule name in the main graph.
+    name: str
+
+    # Stores the placeholder nodes in `graph`.
+    input_placeholders: list = field(default_factory=list)
+
+    # Store the nodes in original graph that are placeholder in `graph`.
+    # Kept index-aligned with `input_placeholders`.
+    orig_inputs: list = field(default_factory=list)
+
+    # Store the nodes in original graph that are outputs in `graph`.
+    orig_outputs: list = field(default_factory=list)
+
+    # Mapping from get_attr node in original graph to get_attr node in `graph`.
+    getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
+    # NOTE(review): not read or written by `split_by_tags` in this file;
+    # presumably consumed by other callers -- confirm before relying on it.
+    constructor_args: list[str] = field(default_factory=list)
+    # GraphModule lifted from `graph`; None until `split_by_tags` builds it.
+    gm: Optional[torch.fx.GraphModule] = None
+
+
+@compatibility(is_backward_compatible=False)
+def split_by_tags(
+    gm: torch.fx.GraphModule,
+    tags: list[str],
+    return_fqn_mapping: bool = False,
+    return_tuple: bool = False,
+    GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
+) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
+    """
+    Splits a GraphModule using tags on its graph nodes. We honor the order of
+    tags. For example, we have tags = ["a", "b", "c"], the function will create
+    the initial submodules in the order of "a", "b", "c".
+
+    To set a tag:
+    gm.graph.nodes[idx].tag = "mytag"
+
+    This will result in all nodes with the same tag being extracted and placed in their
+    own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder
+    and output nodes are created when needed while get_attr nodes get copied to submodules
+    where they are used.
+
+    Given the following module def:
+
+    class SimpleModule(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear1 = torch.nn.Linear(...)
+            self.linear2 = torch.nn.Linear(...)
+            self.linear3 = torch.nn.Linear(...)
+
+        def forward(self, in1, in2):
+            r1 = self.linear1(in1)
+            r2 = self.linear2(in2)
+            r3 = torch.cat([r1, r2])
+            return self.linear3(r3)
+
+    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
+
+    ro:
+    def forward(self, in1):
+        self = self.root
+        linear1 = self.linear1(in1)
+        return linear1
+
+    main:
+    def forward(self, in2, linear1):
+        self = self.root
+        linear2 = self.linear2(in2)
+        cat_1 = torch.cat([linear1, linear2])
+        linear3 = self.linear3(cat_1)
+        return linear3
+
+    main:
+    def forward(self, in1, in2):
+        self = self.root
+        ro_0 = self.ro_0(in1)
+        main_1 = self.main_1(in2, ro_0)
+        return main_1
+
+    Args:
+        gm: the GraphModule to split. Every call_module / call_function /
+            call_method node in ``gm.graph`` must carry a ``tag`` attribute.
+        tags: ordered list of tags; one component (submodule) is created per
+            tag, in this order.
+        return_fqn_mapping: if True, additionally return the fqn mapping
+            collected from ``lift_subgraph_as_module``.
+        return_tuple: if True, every submodule outputs a tuple, even when it
+            has a single output.
+        GraphModuleCls: class used to construct the returned top-level module.
+
+    Returns:
+        split_gm: torch fx graph after split
+        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
+            after split for call_module and get_attr.
+            Only returned when ``return_fqn_mapping`` is True.
+
+    Raises:
+        RuntimeError: if the graph has no output node or more than one.
+        AssertionError: if a callable node has no ``tag``, or if a node's
+            component comes earlier in ``tags`` than one of its upstream
+            components.
+        KeyError: if a node's tag is not listed in ``tags``.
+    """
+
+    def flatten(x: torch.fx.node.Argument) -> NodeList:
+        """
+        Stores nodes in x to a list and returns the list.
+        """
+        r: NodeList = []
+        map_arg(x, r.append)
+        return r
+
+    # Mapping from node in original module to node in created submodule.
+    node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Mapping from node in original module or created submodules to
+    # corresponding component.
+    node_to_component: dict[torch.fx.Node, Component] = {}
+
+    # Mapping from tag to the corresponding component.
+    tag_to_component: dict[str, Component] = {}
+
+    # Stores all components.
+    all_components: list[Component] = []
+
+    # Stores nodes that will be used in main graph.
+    # A dict with None values is used as an insertion-ordered set.
+    used_in_main: dict[torch.fx.Node, None] = {}
+
+    # Main graph after split.
+    main_g = torch.fx.Graph()
+
+    # Mapping from node in original module to node in main graph after split.
+    main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Output node of original module.
+    output_node: Optional[torch.fx.Node] = None
+
+    # Create a component for each tag, we don't expect to create other components afterwards.
+    for tag in tags:
+        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
+        all_components.append(comp)
+        tag_to_component[tag] = comp
+
+    # Traverse the nodes in original graph and take care of them.
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            if output_node is not None:
+                raise RuntimeError("Multiple output nodes in graph!")
+            output_node = node
+            continue
+
+        # Placeholders in the original graph get copied to main graph.
+        if node.op == "placeholder":
+            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
+            main_remapping[node].meta = copy.copy(node.meta)
+            continue
+
+        # Get_attr nodes are ignored because we are not tagging them.
+        # Instead, we copy them directly to the submodules that use them afterwards.
+        if node.op == "get_attr":
+            continue
+
+        # Now we process callable nodes which are nodes with op of call_module,
+        # call_function or call_method. Every callable node should be tagged.
+        assert hasattr(node, "tag"), f"Node does not have tag: {node.format_node()}"
+
+        # Components of all callable nodes this node consumes.
+        upstream_components = [
+            node_to_component[x]
+            for x in flatten(node.args) + flatten(node.kwargs)
+            if x.op not in {"placeholder", "get_attr"}
+        ]
+
+        comp = tag_to_component[node.tag]
+        node_to_component[node] = comp
+
+        # Max order of upstream components.
+        mx = max((c.order for c in upstream_components), default=0)
+
+        # Expect the component for `node` to have an order no lower than that
+        # of its upstream components.
+        assert comp.order >= mx, (
+            f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
+        )
+
+        # Map an input of `node` to nodes in the component's graph.
+        def remap_func(x):
+            # If input is a get_attr node, copy it to current component's graph.
+            # Returns the get_attr node in current component's graph.
+            if x.op == "get_attr":
+                if x not in comp.getattr_maps:
+                    comp.getattr_maps[x] = comp.graph.get_attr(
+                        x.target, type_expr=x.type
+                    )
+                    comp.getattr_maps[x].meta = copy.copy(x.meta)
+                return comp.getattr_maps[x]
+
+            # If input is not a placeholder, it should have been put into a component
+            # already. If it's the current component then we return the corresponding
+            # node in the component.
+            if x.op != "placeholder" and node_to_component[x] == comp:
+                return node_remapping[x]
+
+            # If input is a placeholder or it's in other components, we want to make it
+            # as a placeholder in current component's graph.
+            if x not in comp.orig_inputs:
+                comp.orig_inputs.append(x)
+                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
+                placeholder.meta = copy.copy(x.meta)
+                comp.input_placeholders.append(placeholder)
+                used_in_main[x] = None
+
+            # `orig_inputs` and `input_placeholders` are appended to together,
+            # so the same index addresses the matching placeholder.
+            return comp.input_placeholders[comp.orig_inputs.index(x)]
+
+        # Copy `node` into the component's graph, remapping its inputs.
+        n = comp.graph.node_copy(node, remap_func)
+        n.tag = node.tag  # type: ignore[attr-defined]
+        node_remapping[node] = n
+        node_to_component[n] = comp
+
+    if output_node is None:
+        raise RuntimeError("Graph had no output node!")
+
+    for x in flatten(output_node.args[0]):
+        if x.op == "get_attr":
+            # We don't need components mapping for nodes of type "get_attr"
+            # that are consumed by the output. Only need to make sure we create
+            # corresponding counterparts in the resulting graph.
+            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
+        else:
+            # All component results consumed by the output node should be
+            # marked as "used in main".
+            used_in_main[x] = None
+
+    # If a node is used in main graph then we mark it as an output in the component
+    # it belongs to.
+    for n in used_in_main:
+        if n.op != "placeholder":
+            node_to_component[n].orig_outputs.append(n)
+
+    # Now we create a graphmodule for each component.
+    orig_to_split_fqn_mapping: dict[str, str] = {}
+    for comp in all_components:
+        # Component-graph counterparts of the nodes this component exports.
+        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
+
+        if return_tuple:
+            comp.graph.output(outs)
+        else:
+            # Take care of the args of FX output node. If there's a single
+            # output then the output node args is like (output_single), else
+            # if there're multiple outputs then the output node args is like
+            # ((output_0, output_1, ...)).
+            comp.graph.output(outs[0] if len(outs) == 1 else outs)
+
+        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
+            gm, subgraph=comp.graph, comp_name=comp.name
+        )
+        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
+
+        # Create a call_module node in main graph.
+        main_node = main_g.call_module(
+            comp.name,
+            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
+            kwargs=None,
+        )
+
+        if len(outs) == 1 and not return_tuple:
+            # Single bare output: the call_module node itself is the value.
+            main_remapping[comp.orig_outputs[0]] = main_node
+        else:
+            for i, o in enumerate(comp.orig_outputs):
+                # Use Proxy to record getitem access.
+                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]
+
+    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
+    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
+    # Reuse the original graph's codegen object for the main graph.
+    main_g._codegen = gm.graph._codegen
+
+    # If the output node consumes get_attr directly in the original graph,
+    # then we need to make sure get_attr is copied to the new graph.
+    for x in flatten(output_node.args[0]):
+        if x.op == "get_attr":
+            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]
+
+    result_gm = GraphModuleCls(main_root, main_g)
+    if return_fqn_mapping:
+        return result_gm, orig_to_split_fqn_mapping
+
+    return result_gm
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d90f9d55cfdb194e2d2a0577a84b5fd9d7f0262
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py
@@ -0,0 +1,1121 @@
+# mypy: allow-untyped-defs
+import argparse
+import copy
+import json
+import logging
+import os
+from collections import defaultdict
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+from typing import Any, Literal, NamedTuple, Optional
+
+import torch
+from torch._logging import trace_structured
+from torch.fx._compatibility import compatibility
+from torch.fx.node import map_arg
+from torch.fx.passes.graph_manipulation import get_size_of_node
+
+from .graph_drawer import FxGraphDrawer
+from .operator_support import get_node_target, OperatorSupportBase
+from .shape_prop import ShapeProp
+from .split_utils import split_by_tags
+from .tools_common import (
+    CALLABLE_NODE_OPS,
+    FxNetAccFusionsFinder,
+    is_node_output_tensor,
+    NodeList,
+    NodeSet,
+    Tensors,
+)
+
+
+__all__ = [
+    "FxNetAccNodesFinder",
+    "FxNetSplitterInternalError",
+    "Subgraph",
+    "SplitResult",
+    "generate_inputs_for_submodules",
+    "NodeEvent",
+    "NodeEventTracker",
+]
+_LOGGER = logging.getLogger(__name__)
+
+DEFAULT_MIN_ACC_MODULE_SIZE = 1
+DEFAULT_SKIP_FUSION = False
+DEFAULT_ALLOW_NON_TENSOR = False
+
+# ENV var and constants for node tracker
+
+TRACKER_DUMP_PATH = "_fx_net_tracker"
+NODES_SUFFIX = "_nodes.txt"
+ALL_SUFFIX = "_all.txt"
+
+ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
+ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
+ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
+    "FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
+)
+
+DUMP_PREFIX = os.environ.get(
+    ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
+)
+
+"""
+Different modes of the event tracker for local debugging:
+"0": No local dumps. Information available by setting breakpoints and visually inspect in pdb.
+"1": Dump all events to DUMP_PREFIX_all.txt
+"2": In addition to events dump, track nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES
+     recursively and dump to DUMP_PREFIX_nodex.txt
+"3": In addition to events dump, track all nodes with more than 1 event recursively and dump to DUMP_PREFIX_nodex.txt
+In addition to the above local dumps, tracker is always enabled and dumps via trace_structured.
+"""
+TRACKER_MODE: Literal["0", "1", "2", "3"] = os.environ.get(
+    ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
+)  # type: ignore[assignment]
+
+
+class _SplitterSettingBase:
+    def __init__(
+        self,
+        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
+        skip_fusion=DEFAULT_SKIP_FUSION,
+        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
+        max_acc_splits: int = -1,
+    ):
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--min-acc-module-size",
+            "--min_acc_module_size",
+            required=False,
+            type=int,
+            help="Minimum size limit of an accelerator subgraph.",
+        )
+        parser.add_argument(
+            "--max-acc-splits",
+            "--max_acc_splits",
+            required=False,
+            type=int,
+            help="Enforce a maximum number of split subgraphs.",
+        )
+        parser.add_argument(
+            "--skip-fusion",
+            "--skip_fusion",
+            default=False,
+            action="store_true",
+            help="If true then no fusion groups. Fusion group is used to "
+            "enforce no non-tensor data flow between submodules. If we don't "
+            "have this constrain, setting this to false is recommended as it "
+            "can reduce overhead.",
+        )
+        parser.add_argument(
+            "--allow-non-tensor",
+            "--allow_non_tensor",
+            default=False,
+            action="store_true",
+            help="For some backends non-tensor data flow between cpu and them "
+            "are not allowed. Therefore, if a node supported by accelerator but "
+            "it has non-tensor inputs or outputs to a cpu node we would want to "
+            "consider it as a cpu node during splitting. However, for some backends "
+            "we might not care about non-tensor data flow and we can set this option "
+            "to true to disable the functionality that prevent non-tensor data flow.",
+        )
+        args, _unknown = parser.parse_known_args()
+
+        self.min_acc_module_size: int = (
+            args.min_acc_module_size
+            if args.min_acc_module_size
+            else min_acc_module_size
+        )
+        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
+        self.allow_non_tensor: bool = (
+            args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
+        )
+        self.max_acc_splits: int = max_acc_splits
+
+
+@compatibility(is_backward_compatible=False)
+class NodeEvent:
+    """
+    An event in graph split that happened on a node.
+    source: Subject of the event
+    desc: readable description
+    dep: Optional dependency, usually the node that caused the event.
+    """
+
+    def __init__(
+        self, source: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None
+    ):
+        self.source = source
+        self.desc = desc
+        self.dep = dep
+
+    def to_str(self):
+        # source: The name of the subject of the event.
+        # desc: description of the event, in the format of |
+        # dep: The name of the cause of this event, which is another node, or #
+        # if it's caused by the subject node
+        return f"{self.source.name}: {self.desc} {self.dep.name if self.dep else '#'}"
+
+
+@compatibility(is_backward_compatible=False)
+class NodeEventTracker:
+    """
+    Tracks node events during the splitter execution.
+    """
+
+    def __init__(self, tracker_mode, dump_prefix):
+        self.tracker_mode = tracker_mode
+        self.dump_prefix = dump_prefix
+        # list of events
+        self.events = []
+        # dict from node name to event index
+        self.node_events = {}
+        self.writer = print
+
+    def add(self, node: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
+        """
+        Add a new event to the tracker.
+        """
+        event = NodeEvent(node, desc, dep)
+        self.events.append(event)
+        if node.name not in self.node_events:
+            self.node_events[node.name] = []
+        self.node_events[node.name].append(len(self.events) - 1)
+
+    def print_node(self, node_name, recursive=False, tab="", writer=None):
+        """
+        Print a node and its events.
+        @param recursive: if True, print nodes that caused the events on this current node.
+        @param tab: Indentation for dependencies.
+        @param writer: function to write to file. If None, use print.
+        """
+        if not writer:
+            writer = self.writer
+        for idx in self.node_events.get(node_name, []):
+            event = self.events[idx]
+            writer(tab + event.to_str())
+            if recursive and event.dep is not None:
+                self.print_node(
+                    event.dep.name, recursive=True, tab="| " + tab, writer=writer
+                )
+
+    def to_dict(self):
+        """
+        Create dict dump on all events.
+        """
+        ret: dict[str, list[str]] = {}
+        for name in self.node_events:
+            ret[name] = []
+            for idx in self.node_events.get(name, []):
+                event = self.events[idx]
+                ret[name].append(event.to_str())
+        return ret
+
+    def print_all(self, writer=None):
+        """
+        Print all nodes in a list.
+        @param writer: function to write to file. If None, use print.
+        """
+        if not writer:
+            writer = self.writer
+        for name in self.node_events:
+            writer(f"Node: {name}:")
+            self.print_node(name, recursive=False, tab="  ", writer=writer)
+
+    def dump(self):
+        """
+        Function to be invoked at the end of the finder execution to printout tracked events specified by the mode.
+        """
+        # dump via trace_structured
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_net_acc_splitter_finder_events",
+                "encoding": "json",
+            },
+            payload_fn=lambda: json.dumps(self.to_dict()),
+        )
+
+        def writeln(f):
+            def fn(x):
+                return f.write(x + "\n")
+
+            return fn
+
+        # Mode 0: no local dump
+        # Mode >=1: Dump all events to file
+        if self.tracker_mode >= 1:
+            with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
+                self.print_all(writeln(f))
+
+        def dump_selected_nodes(nodes):
+            with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
+                for node_name in nodes:
+                    writeln(f"===== Tracking node {node_name} =====")
+                    self.print_node(
+                        node_name, recursive=True, tab="|-", writer=writeln(f)
+                    )
+                    writeln(f"===== End of tracking node {node_name} =====")
+
+        # Mode 2: Dump specific nodes in recursive manner.
+        # Mode 3: Dump all nodes with more than 1 event in recursive manner.
+        if self.tracker_mode == 2 or self.tracker_mode == 3:
+            nodes = (
+                os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "").split(
+                    ","
+                )
+                if self.tracker_mode == 2
+                else [
+                    name for name, events in self.node_events.items() if len(events) > 1
+                ]
+            )
+            dump_selected_nodes(nodes)
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetAccNodesFinder:
+    """
+    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
+    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
+
+    I.e. if we have a chain:
+
+    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
+
+    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
+
+    This behavior can be turned off by passing allow_non_tensor=True.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        operator_support: OperatorSupportBase,
+        allow_non_tensor: bool,
+    ):
+        self.module = module
+        self.operator_support = operator_support
+        self.allow_non_tensor = allow_non_tensor
+        self.acc_nodes: NodeSet = set()
+
+        self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)
+
+    def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
+        """
+        Transitively excludes nodes from ACC supported set.
+        For every node in the worklist:
+        - removes its downstream ACC nodes from ACC supported set,
+        - if any downstream ACC node produces non-tensor output,
+          then it gets added into the worklist.
+        """
+        while cpu_worklist:
+            node = cpu_worklist.pop(0)
+
+            for user in node.users:
+                if user in self.acc_nodes:
+                    self.acc_nodes.remove(user)
+                    self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
+                    if not is_node_output_tensor(user):
+                        self.tracker.add(user, "new_cpu_node|non_tensor_output")
+                        cpu_worklist.append(user)
+
+    def reduce_acc_nodes_non_tensor_input(self):
+        """
+        Excludes nodes from ACC supported set that have direct
+        upstream CPU nodes that produce non-tensor outputs.
+        """
+        non_tensor_cpu_nodes: NodeList = []
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+            if node in self.acc_nodes:
+                continue
+            if is_node_output_tensor(node):
+                continue
+            self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
+            non_tensor_cpu_nodes.append(node)
+
+        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
+
+    def reduce_acc_nodes_non_tensor_output(self):
+        """
+        Excludes nodes from ACC supported set that produce non-tensor
+        outputs and have downstream CPU nodes.
+        """
+        while True:
+            new_cpu_nodes: NodeList = []
+
+            for acc_node in self.acc_nodes:
+                if is_node_output_tensor(acc_node):
+                    continue
+                for user in acc_node.users:
+                    if user not in self.acc_nodes:
+                        new_cpu_nodes.append(acc_node)
+                        self.tracker.add(
+                            acc_node, "acc_del|non_tensor_output_with_cpu_user", user
+                        )
+                        break
+
+            if not new_cpu_nodes:
+                break
+
+            for new_cpu_node in new_cpu_nodes:
+                self.acc_nodes.remove(new_cpu_node)
+
+            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)
+
+    def __call__(self) -> NodeSet:
+        """
+        Compute and return the set of nodes to run on the accelerator.
+
+        A node qualifies when its op is one of ``CALLABLE_NODE_OPS`` and
+        ``operator_support`` accepts it. Unless ``allow_non_tensor`` is set,
+        the set is then pruned by the two non-tensor reduction passes. The
+        tracker is dumped before returning.
+        """
+        named_submodules = dict(self.module.named_modules())
+        self.acc_nodes = set()
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                self.tracker.add(node, "init_cpu|not_callable")
+            elif not self.operator_support.is_node_supported(named_submodules, node):
+                self.tracker.add(node, "init_cpu|operator_support")
+            else:
+                self.tracker.add(node, "init_acc|callable_and_operator_supported")
+                self.acc_nodes.add(node)
+
+        if not self.allow_non_tensor:
+            self.reduce_acc_nodes_non_tensor_input()
+            self.reduce_acc_nodes_non_tensor_output()
+        self.tracker.dump()
+        return self.acc_nodes
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetSplitterInternalError(Exception):
+    """Raised when the splitter detects an inconsistent internal state."""
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class Subgraph:
+    """A group of graph nodes that the splitter places into one submodule."""
+
+    # True when the group is destined for the accelerator, False otherwise.
+    is_acc: bool
+    # Nodes that belong to this group.
+    nodes: NodeList
+    # NOTE(review): never read in the visible code; presumably the index of
+    # the target device -- confirm against callers.
+    device_ordinal: Optional[int] = None
+
+
+@compatibility(is_backward_compatible=False)
+class SplitResult(NamedTuple):
+    """
+    Stores the results of the splitter.
+
+    Attributes:
+        split_module: root module after splitting.
+        submodule_inputs: a dict that maps each submodule name to the inputs
+            captured for that submodule.
+        non_acc_submodule_prefix: the name prefix used for non-acc
+            submodules. For acc submodules the prefix is always
+            "_run_on_acc_".
+    """
+
+    split_module: torch.fx.GraphModule
+    submodule_inputs: dict[str, Any]
+    non_acc_submodule_prefix: str
+
+
+@compatibility(is_backward_compatible=False)
+def generate_inputs_for_submodules(
+    model: torch.nn.Module,
+    inputs: Sequence[Any],
+    target_submodules: Iterable[str],
+    deepcopy: bool = False,
+) -> dict[str, Any]:
+    """
+    Generate inputs for the targeted submodules of the given model by running
+    the model once (under ``torch.no_grad()``) with forward pre-hooks attached.
+    Note that if two submodules refer to the same obj, this function doesn't
+    work.
+
+    Args:
+        model: root model.
+        inputs: inputs to the root model.
+        target_submodules: names of the submodules that we want to generate
+            inputs for. Any iterable is accepted, including a one-shot
+            iterator.
+        deepcopy: when True, store a deep copy of each captured input tuple
+            instead of the tuple itself.
+
+    Returns:
+        A dict that maps from submodule name to its inputs.
+
+    Raises:
+        Whatever the forward pass of ``model`` raises; the hooks are removed
+        before the exception propagates.
+    """
+
+    # Materialize once: repeated `in` tests would exhaust a one-shot iterator
+    # on the first lookup and silently skip every target.
+    target_names = set(target_submodules)
+
+    handles = []
+    results = {}
+    submodule_to_names = {mod: name for name, mod in model.named_modules()}
+
+    def pre_forward(module, module_inputs):
+        results[submodule_to_names[module]] = (
+            copy.deepcopy(module_inputs) if deepcopy else module_inputs
+        )
+
+    try:
+        for name, mod in model.named_modules():
+            # Script modules are skipped, as in the original implementation.
+            if name in target_names and not isinstance(mod, torch.jit.ScriptModule):
+                handles.append(mod.register_forward_pre_hook(pre_forward))
+
+        with torch.no_grad():
+            model(*inputs)
+    finally:
+        # Always detach the hooks, whatever way the block is left.
+        for handle in handles:
+            handle.remove()
+
+    return results
+
+
+class _SplitterBase:
+    """
+    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
+    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
+    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
+
+    Given the following graph:
+          ==> b ==>
+        //         \\
+       a             d
+        \\         //
+          ==> c ==>
+
+    class SimpleModule(torch.nn.Module):
+        def forward(self, a):
+            b = torch.sin(a)
+            c = torch.cos(a)
+            d = b + c
+            return d
+
+    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
+    we will get the following split result:
+
+    main:
+    def forward(self, a):
+        run_on_acc_0_0 = self._run_on_acc_0_0(a)
+        getitem = run_on_acc_0_0[0]
+        getitem_1 = run_on_acc_0_0[1]
+        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
+        return run_on_cpu_1_1
+
+    _run_on_acc_0_0:
+    def forward(self, a):
+        sin_1 = torch.sin(a)
+        cos_1 = torch.cos(a)
+        return (sin_1, cos_1)
+
+    _run_on_cpu_1_1:
+    def forward(self, sin_1, cos_1):
+        add_1 = sin_1 + cos_1
+        return add_1
+
+    NOTE(review): ``tag()`` in this class names submodules
+    ``_run_on_acc_<i>`` / ``<non_acc_submodule_name><i>`` with a single index,
+    so the ``_0_0`` / ``_1_1`` suffixes in the example above look stale --
+    confirm against actual output.
+    """
+
+    # PCIe bandwidth for the backend in bytes per second; defaults to
+    # 100 * 2**30 (100 GiB/s). split_preview() divides it by a byte count.
+    PCIe_BW = 100 * 2**30
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Sequence[Any],
+        operator_support: OperatorSupportBase,
+        settings: _SplitterSettingBase,
+        non_acc_submodule_name: str = "_run_on_cpu_",
+        return_tuple: bool = False,
+        nodes_finder: Optional[FxNetAccNodesFinder] = None,
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+
+        Args:
+            module: the GraphModule to split; shape propagation is run on it
+                with ``sample_input``.
+            sample_input: inputs unpacked into ``ShapeProp.propagate`` and
+                kept on the instance for later runs of the module.
+            operator_support: decides per node whether it is supported;
+                handed to the default nodes finder.
+            settings: splitter settings; ``allow_non_tensor`` and
+                ``skip_fusion`` are read here.
+            non_acc_submodule_name: name prefix for non-acc submodules.
+            return_tuple: forwarded to ``split_by_tags`` by ``split()``.
+            nodes_finder: callable returning the set of ACC nodes. When
+                None, a ``FxNetAccNodesFinder`` is built from ``module``,
+                ``operator_support`` and ``settings.allow_non_tensor``.
+        """
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+        # Runs before the nodes finder and the preview helpers, which are
+        # called on the shape-propagated module.
+        ShapeProp(self.module).propagate(*sample_input)
+
+        self.settings = settings
+        self.operator_support = operator_support
+        self.sample_input = sample_input
+        if nodes_finder is None:
+            nodes_finder = FxNetAccNodesFinder(
+                self.module, self.operator_support, self.settings.allow_non_tensor
+            )
+        self.acc_nodes = nodes_finder()
+
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
+        self.non_acc_submodule_name = non_acc_submodule_name
+        # Filled in by tag(): node name -> submodule tag.
+        self._node_submodule_map: dict[str, str] = {}
+        self._return_tuple = return_tuple
+
+        # Filled in by tag(): one tag per subgraph, in subgraph order.
+        self.tags: list[str] = []
+
+    # ===============================================================
+    # Helpers for ctor and initial state
+    # ===============================================================
+
+    def get_node_submodule_map(self) -> dict[str, str]:
+        """Returns a map from node name to submodule name, e.g.
+        node: main_module_impl_impl_over_arch_unary_multiple_embedding
+          _pooling_embedding_pooling_sparse_entity_equivalence_key
+          _proxy_embedding_bag
+        maps to submodule name of: _run_on_acc_1
+
+        The map is populated by ``tag()`` and is empty until it has run. The
+        internal dict itself is returned, not a copy.
+        """
+        return self._node_submodule_map
+
+    def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
+        """
+        Build the direct-dependency graph of the module.
+
+        Each key is a node and its value is the set of callable nodes it
+        consumes directly; transitive dependencies are not included. The
+        "output" node never appears as a key, and nodes that are not callable
+        (e.g. placeholders) never appear as dependencies.
+        """
+        direct_deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
+        callable_nodes = (
+            n for n in self.module.graph.nodes if n.op in CALLABLE_NODE_OPS
+        )
+        for producer in callable_nodes:
+            for consumer in producer.users:
+                if consumer.op == "output":
+                    continue
+                direct_deps[consumer].add(producer)
+        return direct_deps
+
+    def update_deps_for_fusions(self):
+        """
+        Rewrite ``self.deps`` so that fusion groups behave as units:
+        - every member of a fusion depends on the outer dependencies of all
+          members of that fusion,
+        - every outer consumer of any member depends on each member.
+        """
+        for node, fusion in self.fusions.items():
+            for member in fusion:
+                node_deps = self.deps[node]
+                # Only dependencies from outside the group are propagated.
+                node_deps.update(self.deps[member] - fusion)
+
+                for consumer in member.users:
+                    if consumer in fusion:
+                        continue
+                    self.deps[consumer].add(node)
+
+    # ===============================================================
+    # Helpers for preview
+    # ===============================================================
+
+    def _lower_model_to_backend(
+        self, mod: torch.fx.GraphModule, inputs: Tensors
+    ) -> torch.nn.Module:
+        """
+        Lower the model to a backend.
+
+        This base implementation returns ``mod`` unchanged and ignores
+        ``inputs``. ``split_preview()`` calls it and treats a RuntimeError
+        raised here as a lowering failure.
+        """
+
+        return mod
+
+    def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
+        """
+        When an error occurs during lowering or running the lowered mod, we use this
+        function to find culprits in the `mod` that causes the error.
+
+        This base implementation does no analysis: it ignores both arguments
+        and returns a fixed "not implemented" message, which
+        ``split_preview()`` appends to its report.
+        """
+
+        return "Unable to find a culprit because _find_culprit() function is not implemented."
+
+    def _draw_graph_based_on_node_support(
+        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
+    ):
+        """
+        Write ``node_support.dot``, a drawing of ``mod`` in which each node is
+        filled according to whether it is in ``supported_nodes``, is a
+        callable node outside that list, or is any other kind of node.
+        """
+        fill_colors = {
+            "default": "AliceBlue",
+            "supported": "chartreuse1",
+            "unsupported": "crimson",
+        }
+
+        class CustomDrawer(FxGraphDrawer):
+            def _get_node_style(self, node):
+                style = super()._get_node_style(node)
+                if node in supported_nodes:
+                    kind = "supported"
+                elif node.op in CALLABLE_NODE_OPS:
+                    kind = "unsupported"
+                else:
+                    kind = "default"
+                style["fillcolor"] = fill_colors[kind]
+                return style
+
+        dot_graph = CustomDrawer(
+            mod, "node_support", ignore_getattr=True
+        ).get_main_dot_graph()
+        # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
+        dot_graph.write_raw("node_support.dot")  # type: ignore[attr-defined]
+
+    def node_support_preview(self, dump_graph: bool = False):
+        """
+        Report which node targets are supported and which are not, together
+        with the dtypes of their node arguments.
+
+        Args:
+            dump_graph: when True, also draw the module via
+                ``_draw_graph_based_on_node_support``.
+
+        Returns:
+            The report text (also printed).
+        """
+        submodules = dict(self.module.named_modules())
+
+        supported_nodes: NodeList = []
+        supported_node_types = defaultdict(set)
+        unsupported_node_types = defaultdict(set)
+
+        def get_dtype(arg):
+            return getattr(arg.meta.get("tensor_meta"), "dtype", None)
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            target = get_node_target(submodules, node)
+
+            # One dtype per positional arg; None for anything that is not a
+            # node or carries no dtype.
+            positional = [
+                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
+                for arg in node.args
+            ]
+            # Drop trailing None entries (leaves an empty list if all are None).
+            while positional and positional[-1] is None:
+                positional.pop()
+
+            keyword = tuple(
+                (k, get_dtype(arg))
+                for k, arg in node.kwargs.items()
+                if isinstance(arg, torch.fx.Node)
+            )
+            signature = (tuple(positional), keyword)
+
+            if self.operator_support.is_node_supported(submodules, node):
+                supported_nodes.append(node)
+                supported_node_types[target].add(signature)
+            else:
+                unsupported_node_types[target].add(signature)
+
+        if dump_graph:
+            self._draw_graph_based_on_node_support(self.module, supported_nodes)
+
+        pieces = ["\nSupported node types in the model:\n"]
+        pieces.extend(
+            f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
+            for t, dtypes in supported_node_types.items()
+            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes
+        )
+        pieces.append("\nUnsupported node types in the model:\n")
+        pieces.extend(
+            f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
+            for t, dtypes in unsupported_node_types.items()
+            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes
+        )
+        reports = "".join(pieces)
+
+        print(reports)
+
+        # Return reports for testing purpose
+        return reports
+
+    def split_preview(self, dump_graph: bool = False):
+        """
+        Dry-run the split and report on every accelerator submodule.
+
+        The report lists subgraph counts before and after small ACC subgraphs
+        are merged away, then, for each ``call_module`` node whose target
+        contains "acc": the total tensor input/output bytes, a PCIe-bound
+        throughput estimate, and whether lowering and running it succeeded.
+        The tags set on the nodes for this preview are removed again.
+
+        Args:
+            dump_graph: when True, write one ``<name>.dot`` file per dot graph
+                of the split module.
+
+        Returns:
+            The report text (also printed).
+        """
+        reports = ""
+        subgraphs = self.put_nodes_into_subgraphs()
+        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
+        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
+        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
+        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
+
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
+        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
+        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
+        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
+
+        for i, subgraph in enumerate(subgraphs):
+            reports += (
+                f"_run_on_acc_{i}: "
+                if subgraph.is_acc
+                else f"{self.non_acc_submodule_name}{i}: "
+            )
+            reports += f"{len(subgraph.nodes)} node(s)\n"
+
+        self.tag(subgraphs)
+        split_mod = self.split(remove_tag=True)
+        split_mod.eval()
+
+        if dump_graph:
+            drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
+            dot_graphs = drawer.get_all_dot_graphs()
+            for name, dot_graph in dot_graphs.items():
+                # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
+                dot_graph.write_raw(f"{name}.dot")  # type: ignore[attr-defined]
+
+        def get_submod_inputs(main_mod, submod, example_inputs):
+            """Run ``main_mod`` once and capture what ``submod`` receives."""
+            sub_inputs = None
+
+            def get_inputs(_module, inputs):
+                nonlocal sub_inputs
+                sub_inputs = inputs
+
+            handle = submod.register_forward_pre_hook(get_inputs)
+            try:
+                main_mod(*example_inputs)
+            finally:
+                # Never leave the capture hook attached, even on failure.
+                handle.remove()
+            return sub_inputs
+
+        max_qps: float = self.PCIe_BW
+        bottleneck_module = ""
+
+        for node in split_mod.graph.nodes:
+            if node.op == "call_module" and "acc" in node.target:
+                reports += f"\nProcessing acc submodule {node.target}\n"
+
+                submod = getattr(split_mod, node.target)
+
+                submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
+                ShapeProp(submod).propagate(*submod_inputs)
+
+                total_input_bytes = 0
+                total_output_bytes = 0
+
+                reports += "Checking inputs...\n"
+                for n in submod.graph.nodes:
+                    if n.op == "placeholder":
+                        if not is_node_output_tensor(n):
+                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
+                        else:
+                            total_input_bytes += get_size_of_node(submod, n)[0]
+                    if n.op == "output":
+                        output_node = n
+
+                reports += "Checking outputs...\n"
+
+                def get_bytes(node: torch.fx.Node):
+                    nonlocal total_output_bytes
+                    nonlocal reports
+                    if not is_node_output_tensor(node):
+                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
+                    else:
+                        total_output_bytes += get_size_of_node(submod, node)[0]
+
+                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
+
+                # A submodule that moves no tensor bytes is not bound by PCIe
+                # at all; report it as unbounded instead of dividing by zero.
+                transferred_bytes = max(total_input_bytes, total_output_bytes)
+                qps = (
+                    self.PCIe_BW / transferred_bytes
+                    if transferred_bytes
+                    else float("inf")
+                )
+                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
+                reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"
+
+                if qps < max_qps:
+                    max_qps = qps
+                    bottleneck_module = node.target
+
+                try:
+                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
+                except RuntimeError:
+                    reports += "Run into an error during lowering!\n"
+                    reports += self._find_culprit(submod, submod_inputs)
+                    continue
+
+                try:
+                    lowered_submod(*submod_inputs)
+                except RuntimeError:
+                    reports += "Run into an error during inference!\n"
+                    reports += self._find_culprit(submod, submod_inputs)
+                else:
+                    reports += "Lowering and running succeed!\n"
+
+        reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
+        reports += f" bottleneck is submodule {bottleneck_module}."
+        print(reports)
+
+        # return the reports for testing purposes
+        return reports
+
+    # ===============================================================
+    # Helpers for extend_acc_subgraph() method
+    # ===============================================================
+
+    def find_reverse_deps(
+        self, tag_id: Optional[int] = None
+    ) -> dict[torch.fx.Node, NodeSet]:
+        """
+        Map every callable node to the set of callable nodes that consume it.
+
+        When ``tag_id`` is given, a consumer is recorded only if the numeric
+        suffix of its tag is strictly smaller than ``tag_id``; consumers
+        tagged at ``tag_id`` or later are left out.
+        """
+        users_of: dict[torch.fx.Node, NodeSet] = defaultdict(set)
+
+        for producer in self.module.graph.nodes:
+            if producer.op not in CALLABLE_NODE_OPS:
+                continue
+
+            for consumer in producer.users:
+                if consumer.op not in CALLABLE_NODE_OPS:
+                    continue
+
+                if tag_id is not None and int(consumer.tag.split("_")[-1]) >= tag_id:
+                    continue
+
+                users_of[producer].add(consumer)
+
+        return users_of
+
+    def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
+        """
+        Adjust a reverse-dependency map in place so fusion groups act as units.
+
+        For each fusion group, every member ends up with the union of all
+        members' entries in ``deps`` minus the group itself, and every input
+        node outside the group gains all members of the group in its entry.
+
+        Args:
+            deps: mapping as produced by ``find_reverse_deps``; modified in
+                place.
+        """
+        # Members already handled through another member of their group.
+        processed_node = set()
+
+        for node, fusion in self.fusions.items():
+            if node in processed_node:
+                continue
+
+            new_dep = set()
+
+            # Create a new dependency set which include all the
+            # dependencies of the nodes in the fusion group
+            for n in fusion:
+                new_dep.update(deps[n])
+
+            # Exclude nodes in the fusion
+            new_dep.difference_update(fusion)
+
+            # Update dependency
+            for n in fusion:
+                # Every member is bound to the same set object, not a copy.
+                deps[n] = new_dep
+
+                for arg in n.all_input_nodes:
+                    if arg not in fusion:
+                        deps[arg].update(fusion)
+
+                processed_node.add(n)
+
+    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
+        """
+        Collect the parent nodes of the subgraph tagged ``tag``.
+
+        A parent is a callable input of some callable node in the subgraph
+        that itself carries a different tag. Inputs that are not callable
+        nodes (such as placeholders) are never parents.
+        """
+        return {
+            arg
+            for node in self.module.graph.nodes
+            if node.op in CALLABLE_NODE_OPS and node.tag == tag
+            for arg in node.all_input_nodes
+            if arg.op in CALLABLE_NODE_OPS and arg.tag != tag
+        }
+
+    def extend_acc_subgraph(self, tag: str):
+        """
+        Grow the acc subgraph ``tag`` backwards, i.e. against the topological
+        direction, by re-tagging eligible upstream ACC nodes with ``tag``.
+
+        A candidate is absorbed only when it is an ACC node and all of its
+        recorded users have already been absorbed. The walk stops as soon as
+        no candidate qualifies.
+        """
+        # Node -> users, ignoring users that sit in a subgraph with a
+        # greater-or-equal tag index.
+        tag_index = int(tag.rsplit("_", maxsplit=1)[-1])
+        deps = self.find_reverse_deps(tag_id=tag_index)
+        self.update_reverse_deps_for_fusions(deps)
+
+        frontier = self.find_parent_nodes_of_subgraph(tag)
+        absorbed: NodeSet = set()
+
+        while frontier:
+            # An acc node whose users have all been absorbed already.
+            candidate = next(
+                (n for n in frontier if deps[n] <= absorbed and n in self.acc_nodes),
+                None,
+            )
+            if candidate is None:
+                break
+
+            # Move the node into the `tag` subgraph.
+            candidate.tag = tag  # type: ignore[attr-defined]
+            frontier.remove(candidate)
+            absorbed.add(candidate)
+
+            # Fusion buddies have to be considered as well.
+            if candidate in self.fusions:
+                frontier.update(
+                    buddy for buddy in self.fusions[candidate] if buddy not in absorbed
+                )
+
+            # Keep walking upstream through the node's callable inputs.
+            frontier.update(
+                arg
+                for arg in candidate.all_input_nodes
+                if arg.op in CALLABLE_NODE_OPS and arg not in absorbed
+            )
+
+    # ===============================================================
+    # Helpers for split() method
+    # ===============================================================
+
+    def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
+        """
+        Find the nodes the traversal starts from: users of placeholder or
+        get_attr nodes, plus call_function nodes that have no input nodes.
+
+        Returns:
+            ``(cpu_starters, acc_starters)``, split by membership in
+            ``self.acc_nodes``.
+        """
+        starter_cpu_nodes: NodeSet = set()
+        starter_acc_nodes: NodeSet = set()
+
+        def classify(candidate):
+            bucket = (
+                starter_acc_nodes if candidate in self.acc_nodes else starter_cpu_nodes
+            )
+            bucket.add(candidate)
+
+        for node in self.module.graph.nodes:
+            # edge case, call_function, but with no dependencies
+            if node.op == "call_function" and len(node.all_input_nodes) == 0:
+                classify(node)
+
+            if node.op in {"placeholder", "get_attr"}:
+                for user in node.users:
+                    classify(user)
+
+        return starter_cpu_nodes, starter_acc_nodes
+
+    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
+        """
+        Partition the callable nodes into alternating ACC / CPU subgraphs.
+
+        Starting from ``starter_nodes()``, nodes of the current kind whose
+        dependencies in ``self.deps`` are all visited are appended to the
+        current subgraph. When none is ready, the subgraph is closed and the
+        kind flips.
+
+        Returns:
+            The subgraphs in creation order.
+
+        Raises:
+            FxNetSplitterInternalError: if a subgraph would have to be closed
+                while empty, or if no subgraph was produced at all.
+        """
+        # We start graph traversal from leaf nodes
+        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
+        visited_nodes: NodeSet = set()
+
+        # Determine which subgraph to start from based on which subgraph has
+        # 0-dep node
+        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)
+
+        current_subgraph_nodes: NodeList = []
+
+        # Result accumulator
+        subgraphs: list[Subgraph] = []
+        while current_cpu_nodes or current_acc_nodes:
+            # Find the first node that should belong to the current subgraph and has all dependencies resolved
+            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
+            node = next(
+                (n for n in current_nodes if self.deps[n] <= visited_nodes),
+                None,
+            )
+
+            # If nothing was found, then it's time to flip the mode and start a new subgraph
+            if node is None:
+                if not current_subgraph_nodes:
+                    raise FxNetSplitterInternalError("Subgraph can't be empty")
+
+                subgraphs.append(
+                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
+                )
+                acc_subgraph = not acc_subgraph
+                current_subgraph_nodes = []
+                continue
+
+            current_nodes.remove(node)
+            visited_nodes.add(node)
+            current_subgraph_nodes.append(node)
+
+            # Add fusion buddies
+            if node in self.fusions:
+                if node in self.acc_nodes:
+                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
+                else:
+                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)
+
+            # Put depending nodes into the queue
+            for user in node.users:
+                if user.op not in CALLABLE_NODE_OPS:
+                    continue
+
+                # Add downstream nodes
+                if user in self.acc_nodes:
+                    current_acc_nodes.add(user)
+                else:
+                    current_cpu_nodes.add(user)
+
+        # Check if the last subgraph was not created
+        if current_subgraph_nodes:
+            subgraphs.append(
+                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
+            )
+
+        if not subgraphs:
+            raise FxNetSplitterInternalError("Couldn't create subgraphs")
+
+        return subgraphs
+
+    def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
+        """
+        Fold ACC subgraphs smaller than ``settings.min_acc_module_size`` into
+        the preceding subgraph, and merge CPU subgraphs that become adjacent.
+
+        A too-small ACC subgraph with nothing before it is turned into a CPU
+        subgraph. The ``Subgraph`` objects passed in are modified in place and
+        reused in the returned list.
+        """
+        merged: list[Subgraph] = []
+        for subgraph in subgraphs:
+            if not subgraph.is_acc:
+                if merged and not merged[-1].is_acc:
+                    merged[-1].nodes.extend(subgraph.nodes)
+                else:
+                    merged.append(subgraph)
+                continue
+
+            if len(subgraph.nodes) >= self.settings.min_acc_module_size:
+                merged.append(subgraph)
+                continue
+
+            print(
+                "Eliminating acc subgraph because it's smaller than the threshold: "
+                f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+            )
+            if not merged:
+                subgraph.is_acc = False
+                merged.append(subgraph)
+            else:
+                merged[-1].nodes.extend(subgraph.nodes)
+        return merged
+
+    def tag(self, subgraphs: list[Subgraph]):
+        """
+        Assign one tag per subgraph and stamp it on each of its nodes.
+
+        Tags are ``_run_on_acc_<i>`` for ACC subgraphs and
+        ``<non_acc_submodule_name><i>`` otherwise, where ``i`` is the
+        subgraph's position. ``self.tags`` is rebuilt and
+        ``self._node_submodule_map`` gains one entry per node.
+
+        Raises:
+            FxNetSplitterInternalError: if a node already has a ``tag``
+                attribute.
+        """
+        self.tags = []
+        for index, subgraph in enumerate(subgraphs):
+            prefix = "_run_on_acc_" if subgraph.is_acc else self.non_acc_submodule_name
+            tag = f"{prefix}{index}"
+            self.tags.append(tag)
+
+            for node in subgraph.nodes:
+                if hasattr(node, "tag"):
+                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")
+
+                self._node_submodule_map[node.name] = tag
+                node.tag = tag  # type: ignore[attr-defined]
+
+    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
+        """
+        Split ``self.module`` by the tags in ``self.tags``.
+
+        Args:
+            remove_tag: when True, delete the ``tag`` attribute from every
+                node of ``self.module`` after splitting.
+
+        Returns:
+            The module produced by ``split_by_tags``.
+        """
+        split_module = split_by_tags(
+            self.module, self.tags, return_tuple=self._return_tuple
+        )
+        if not remove_tag:
+            return split_module  # type: ignore[return-value]
+
+        for node in self.module.graph.nodes:
+            if hasattr(node, "tag"):
+                delattr(node, "tag")
+        return split_module  # type: ignore[return-value]
+
+    def __call__(self) -> torch.fx.GraphModule:
+        """
+        Run the whole pipeline: group nodes into subgraphs, fold away small
+        ACC subgraphs, tag the nodes and split the module. The subgraph
+        counts are printed along the way.
+        """
+        subgraphs = self.remove_small_acc_subgraphs(self.put_nodes_into_subgraphs())
+        acc_subgraphs_count = sum(1 for s in subgraphs if s.is_acc)
+        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
+        print(
+            f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs"
+        )
+        self.tag(subgraphs)
+        return self.split()
+
+    def generate_split_results(self) -> SplitResult:
+        """
+        Split the module and capture sample inputs for every direct child of
+        the split module.
+
+        Raises:
+            ValueError: if ``settings.max_acc_splits`` is positive and the
+                split module has more direct children than that limit.
+        """
+        split_module = self()
+        submodule_names = [name for name, _mod in split_module.named_children()]
+
+        limit = self.settings.max_acc_splits
+        if 0 < limit < len(submodule_names):
+            raise ValueError(
+                "Cannot fulfill max_acc_splits limit. "
+                "This may cause split fragmentation and "
+                "result in performance issues."
+            )
+
+        submodule_inputs = generate_inputs_for_submodules(
+            split_module, self.sample_input, submodule_names
+        )
+        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tools_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tools_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a8f0df8449749167c4ec3dedaf719d78fad577
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tools_common.py
@@ -0,0 +1,390 @@
+# mypy: allow-untyped-defs
+import collections
+import heapq
+import operator
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.node import _get_qualified_name
+
+
+__all__ = [
+    "get_acc_ops_name",
+    "get_node_target",
+    "is_node_output_tensor",
+    "FxNetAccFusionsFinder",
+    "legalize_graph",
+    "stable_topological_sort",
+]
+
+Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
+TensorOrTensors = Union[torch.Tensor, Tensors]
+NodeList = list[torch.fx.Node]
+NodeSet = set[torch.fx.Node]
+Names = list[str]
+CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
+
+
+@compatibility(is_backward_compatible=False)
+def get_acc_ops_name(k):
+    if isinstance(k, str):
+        return k
+    elif k.__module__ and "acc_ops" in k.__module__:
+        return f"acc_ops.{k.__name__}"
+    else:
+        module = k.__module__.replace(
+            "torch._ops", "torch.ops"
+        )  # WAR for bug in how torch.ops assigns module
+        return f"{module if module else ''}.{k.__name__}"
+
+
+@compatibility(is_backward_compatible=False)
+def get_node_target(
+    submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
+) -> str:
+    """
+    Given a `node` returns its target typename.
+
+    For "call_method" node, return node.target which is the name of that method being called.
+    This could potential lead to conflict but should be okay because normally it's on a tensor.
+
+    For "call_function" node, return typename of node.target.
+
+    For "call_module" node, return typename of the module that node.target point to.
+
+    If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
+    "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
+    """
+
+    assert node.op in CALLABLE_NODE_OPS, (
+        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
+    )
+
+    if node.op == "call_module":
+        assert isinstance(node.target, str)
+        submod = submodules[node.target]
+        submod_type = getattr(submod, "_base_class_origin", type(submod))
+        return get_acc_ops_name(submod_type)
+    elif node.op == "call_function":
+        target: Any = node.target
+        return (
+            f"acc_ops.{target.__name__}"
+            if target.__module__ is not None and "acc_ops" in target.__module__
+            else _get_qualified_name(target)
+        )
+    else:
+        assert isinstance(node.target, str)
+        return node.target
+
+
+@compatibility(is_backward_compatible=False)
+def is_node_output_tensor(node: torch.fx.Node) -> bool:
+    """Checks if the node output produces a Tensor or not.
+
+    NOTE: This requires to run `ShapeProp` on the containing fx graph before
+    calling this function. This is because it works by checking the `type`
+    metadata on the node. This metadata is produced by the `ShapeProp`.
+    """
+    type_ = node.meta.get("type", None)
+    return type_ is not None and issubclass(type_, torch.Tensor)
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetAccFusionsFinder:
+    """
+    Finds groups of connected ACC nodes that pass non-tensor data between each other.
+    Such groups are called fusion groups.
+    """
+
+    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
+        self.module = module
+        self.nodes = list(module.graph.nodes)
+        self.acc_nodes = acc_nodes
+
+    @dataclass
+    class FusionGroup:
+        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
+        top_node_idx: int
+
+        # Nodes in this fusion group.
+        nodes: NodeSet
+
+        # Inputs to this fusion group.
+        inputs: NodeSet
+
+        # Nodes that in the fusion group that haven't been processed yet.
+        nodes_need_process: NodeSet
+
+        def add_node(self, node):
+            """
+            Add a node to fusion group.
+            """
+            if node in self.nodes:
+                return
+
+            self.nodes_need_process.add(node)
+            self.nodes.add(node)
+            self.inputs.discard(node)
+            self.inputs.update(
+                {
+                    n
+                    for n in node.all_input_nodes
+                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
+                }
+            )
+
+    def recursive_add_node(
+        self,
+        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
+        inputs: Union[NodeSet, NodeList],
+        visited: Optional[NodeSet] = None,
+    ):
+        """
+        Start from inputs and going reverse topological order. If any upstream node
+        is in the fusion group, add all the nodes in this path to fusion group.
+        """
+        for arg in inputs:
+            # skip the node if already seen
+            if visited is not None:
+                if arg in visited:
+                    continue
+                visited.add(arg)
+
+            # Skip placeholder and get_attr because they won't be in the fusion group.
+            if arg.op not in CALLABLE_NODE_OPS:
+                continue
+
+            # If the node has smaller idx, it's already an upstream node of the fusion
+            # group. We don't need to check it anymore.
+            if self.nodes.index(arg) < fusion_group.top_node_idx:
+                continue
+
+            # If the node is in the fusion group, return True.
+            if arg in fusion_group.nodes:
+                return True
+
+            # Check the upstream nodes of the node, if any of them is in the fusion group
+            # we'll add this node to fusion group and return True.
+            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
+                fusion_group.add_node(arg)
+                return True
+
+        return False
+
+    def __call__(self) -> dict[torch.fx.Node, NodeSet]:
+        result: dict[torch.fx.Node, NodeSet] = {}
+        acc_nodes = list(self.acc_nodes)
+
+        for node in acc_nodes:
+            if node in result:
+                continue
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+            if "tensor_meta" in node.meta:
+                continue
+            if node not in self.acc_nodes:
+                continue
+
+            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
+                top_node_idx=self.nodes.index(node),
+                nodes={node},
+                inputs=set(node.all_input_nodes),
+                nodes_need_process={node},
+            )
+            while fusion_group.nodes_need_process:
+                node = fusion_group.nodes_need_process.pop()
+                self.recursive_add_node(
+                    fusion_group,
+                    fusion_group.inputs,
+                    visited=set(),
+                )
+
+                # Optionally add downstream nodes
+                if "tensor_meta" not in node.meta:
+                    for user in node.users:
+                        if user.op not in CALLABLE_NODE_OPS:
+                            continue
+                        if user in fusion_group.nodes:
+                            continue
+
+                        fusion_group.add_node(user)
+                        self.recursive_add_node(
+                            fusion_group,
+                            fusion_group.inputs,
+                            visited=set(),
+                        )
+
+                # Add some upstream nodes
+                for arg in node.all_input_nodes:
+                    if arg.op not in CALLABLE_NODE_OPS:
+                        continue
+                    if "tensor_meta" in arg.meta:
+                        continue
+                    if arg in fusion_group.nodes:
+                        continue
+
+                    fusion_group.add_node(arg)
+                    fusion_group.top_node_idx = min(
+                        fusion_group.top_node_idx, self.nodes.index(arg)
+                    )
+                    self.recursive_add_node(
+                        fusion_group,
+                        fusion_group.inputs,
+                        visited=set(),
+                    )
+
+            if not (set(fusion_group.nodes) <= self.acc_nodes):
+                self.acc_nodes -= fusion_group.nodes
+            else:
+                for n in fusion_group.nodes:
+                    result[n] = fusion_group.nodes
+
+        return result
+
+
+@compatibility(is_backward_compatible=False)
+def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Replace the graph of the given GraphModule with one that contains the same nodes as the
+    original, but in topologically sorted order.
+
+    This is used by the merge_matmul transformation below, which disturbs the topologically sorted
+    order of its input GraphModule, so that this order is restored before further transformation.
+
+    Arguments:
+        gm: The graph module to topologically sort. It is modified in-place.
+
+    Returns:
+        The graph module in-place sorted
+
+    Warning:
+        This topological sort is NOT stable, it will NOT preserve the original node order.
+        If you need a stable topological sort, use stable_topological_sort instead.
+    """
+
+    # These operators are used for making runtime assertions before any
+    # data-dependent operators occur. We want to prioritize sorting these to
+    # ensure that these assertions appear before any data-dependent operations
+    # in the graph.
+    PRIORITIZED_OPS = [
+        operator.add,
+        operator.mul,
+        operator.sub,
+        operator.floordiv,
+        operator.truediv,
+        operator.mod,
+        operator.le,
+        operator.lt,
+        operator.ge,
+        operator.gt,
+        operator.eq,
+        operator.ne,
+        torch.ops.aten.sym_constrain_range.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
+        torch.ops.aten._assert_async.msg,
+        torch.ops.aten.scalar_tensor.default,
+        torch.ops.aten._assert_scalar.default,
+    ]
+
+    indeg = dict.fromkeys(gm.graph.nodes, 0)
+    new_graph = torch.fx.Graph()
+    # Track how many unfulfilled dependencies each node has
+    for node in gm.graph.nodes:
+        for user in node.users:
+            indeg[user] += 1
+    queue: collections.deque = collections.deque()
+    # Add all nodes with no dependencies to the queue
+    for node in gm.graph.nodes:
+        if indeg[node] == 0:
+            queue.append(node)
+    env: dict[torch.fx.Node, torch.fx.Node] = {}
+    # Pop nodes from the queue, and add nodes that have had all their
+    # dependencies fulfilled
+    while len(queue) > 0:
+        cur = queue.popleft()
+        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
+        for user in cur.users:
+            indeg[user] -= 1
+            if indeg[user] == 0:
+                if user.op == "call_function" and user.target in PRIORITIZED_OPS:
+                    queue.appendleft(user)
+                else:
+                    queue.append(user)
+    # If the new graph's size is not as large as the old one, then there must be
+    # a cycle (i.e. some node's dependencies were not satisfied.)
+    if len(new_graph.nodes) < len(gm.graph.nodes):
+        raise RuntimeError(
+            f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
+        )
+    new_graph._codegen = gm.graph._codegen
+    gm.graph = new_graph
+    return gm
+
+
+@compatibility(is_backward_compatible=False)
+def stable_topological_sort(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Replace the graph of the given GraphModule with one that contains the same nodes as the
+    original, but in topologically sorted order while preserving the original node order
+    as much as possible.
+
+    This function performs a stable topological sort where nodes appear in an order that:
+    1. Respects data dependencies (topological ordering)
+    2. Preserves the original node order when there are no dependency constraints
+
+    The algorithm uses Kahn's algorithm with a priority queue: nodes with all dependencies
+    satisfied are added to a min-heap, ordered by their original position. This ensures
+    we always process the earliest node in the original order among ready nodes.
+
+    Arguments:
+        gm: The graph module to topologically sort. It is modified in-place.
+
+    Returns:
+        The graph module in-place sorted
+    """
+    indeg = dict.fromkeys(gm.graph.nodes, 0)
+    new_graph = torch.fx.Graph()
+
+    # Build node to original index mapping
+    node_to_id: dict[torch.fx.Node, int] = {
+        node: idx for idx, node in enumerate(gm.graph.nodes)
+    }
+
+    # Track how many unfulfilled dependencies each node has
+    for node in gm.graph.nodes:
+        for user in node.users:
+            indeg[user] += 1
+
+    # Priority queue: (original_index, node)
+    # Use min-heap to always process the node with smallest original index
+    ready_queue: list[tuple[int, torch.fx.Node]] = []
+    for node in gm.graph.nodes:
+        if indeg[node] == 0:
+            heapq.heappush(ready_queue, (node_to_id[node], node))
+
+    env: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Process nodes
+    while ready_queue:
+        # Pop node with smallest original index
+        _, cur = heapq.heappop(ready_queue)
+        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
+
+        # Update in-degrees and add newly ready nodes
+        for user in cur.users:
+            indeg[user] -= 1
+            if indeg[user] == 0:
+                heapq.heappush(ready_queue, (node_to_id[user], user))
+
+    # Check if all nodes were processed
+    assert len(new_graph.nodes) == len(gm.graph.nodes), (
+        f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
+    )
+
+    new_graph._codegen = gm.graph._codegen
+    gm.graph = new_graph
+    return gm
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h
new file mode 100644
index 0000000000000000000000000000000000000000..9fd81d3273ce57027528b65071917fa80f4729a0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h
@@ -0,0 +1,486 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+// This file contains helper functions for batching rules.
+
+namespace at::functorch {
+
+TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
+TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
+
+TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);
+
+Tensor moveBatchDimToFront(Tensor tensor, std::optional maybe_batch_dim);
+int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional maybe_batch_dim);
+int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional maybe_batch_dim);
+std::optional valIfNonempty(std::optional maybe_empty, int64_t new_val);
+int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
+VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
+
+void vmapIncompatibleInplaceError(const char* schema_name);
+
+Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional has_bdim, int64_t logical_rank);
+
+void check_randomness(RandomnessType randomness);
+void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
+
+// Guarantees that the returned tensor carries a batch dimension.
+//
+// If `has_bdim` is true, `tensor` is returned unchanged. Otherwise the tensor
+// is expanded (via `expand_symint`) to the shape `[batch_size, *sym_sizes]`,
+// i.e. a new leading dimension of size `batch_size` is added at index 0.
+inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
+  if (has_bdim) {
+    return tensor;
+  }
+  const auto sizes = tensor.sym_sizes();
+  SymDimVector expanded_shape;
+  // One extra slot for the leading batch dimension that is prepended below.
+  expanded_shape.reserve(sizes.size() + 1);
+  expanded_shape.emplace_back(std::move(batch_size));
+  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
+  return tensor.expand_symint(expanded_shape);
+}
+
+#define VMAP_SUPPORT(op, batch_rule) \
+  m.impl(#op, op ## _generated_plumbing);
+
+#define VMAP_SUPPORT2(op, overload, batch_rule) \
+  m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing);
+
+#define OP_DECOMPOSE(op)  m.impl(#op, static_cast(native::op));
+#define OP_DECOMPOSE2(op, overload)  m.impl(#op"."#overload, static_cast(native::op));
+
+// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
+// NOTE(review): the template parameter/argument lists in this copy of the
+// header appear to have been stripped (bare `template`, `std::optional` with no
+// `<...>`); the code is left untouched and only described here.
+//
+// Primary template: declared only. The partial specialization below supplies
+// the actual `apply` implementation.
+template 
+struct BasicUnaryBatchRuleHelper;
+
+template 
+struct BasicUnaryBatchRuleHelper> {
+  // Calls `Func` on `tensor`, forwarding `extra_args`, and returns the result
+  // paired with the incoming `batch_dim`, which is passed through unchanged.
+  static std::tuple> apply(
+      const Tensor& tensor,
+      std::optional batch_dim,
+      T... extra_args) {
+    return std::make_tuple(Func(tensor, std::forward(extra_args)...), batch_dim);
+  }
+};
+
+// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
+// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
+// It is important that this macro is not passed a function pointer!!
+#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
+    BasicUnaryBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits::parameter_types>::apply)
+
+#define UNARY_POINTWISE(op) \
+  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
+
+// Primary template: declared only; see the partial specialization below.
+// NOTE(review): template parameter lists appear stripped in this copy.
+template 
+struct VariadicBdimsBatchRuleHelper;
+
+template 
+struct VariadicBdimsBatchRuleHelper> {
+  // Moves the batch dim of `tensor` to the front (moveBatchDimToFront), calls
+  // `Func` on the result with the forwarded `extra_args`, and reports the
+  // output's batch dim as 0.
+  static std::tuple> apply(
+      const Tensor& tensor,
+      std::optional batch_dim,
+      T... extra_args) {
+    auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
+    return std::make_tuple(Func(tensor_, std::forward(extra_args)...), 0);
+  }
+};
+
+// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
+// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
+// It is important that this macro is not passed a function pointer!!
+#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
+    VariadicBdimsBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits::parameter_types>::apply)
+
+#define VARIADIC_BDIMS(op) \
+  VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));
+
+#define VARIADIC_BDIMS2(op, overload) \
+  VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
+
+// Boxed batching rule driven by a caller-supplied `Func` that prepares the
+// unwrapped tensor inputs before the operator is re-dispatched.
+//
+// Flow:
+//   1. If no argument participates in the current vmap level, call the op
+//      as-is.
+//   2. Otherwise pop the arguments, unwrap every Tensor argument at the current
+//      level and hand the (value, bdim) pairs to `Func`.
+//   3. Push the arguments back (tensors replaced by the values `Func` left in
+//      `tensor_inputs`), call the op, and wrap every returned tensor with
+//      `makeBatched(..., 0, cur_level)`.
+// Non-tensor returns trigger an internal assert.
+// NOTE(review): template parameter lists appear stripped in this copy.
+template
+void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  // Exclude FuncTorchBatched for the rest of this function.
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");
+
+  int64_t cur_level = maybe_layer->layerId();
+
+  // Fast path: nothing on the stack is batched at this level.
+  auto orig_arguments = torch::jit::last(*stack, num_arguments);
+  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    op.callBoxed(stack);
+    return;
+  }
+
+  auto arguments = torch::jit::pop(*stack, num_arguments);
+  // Unwrapped tensor arguments and, in parallel, their argument positions.
+  std::vector>> tensor_inputs;
+  std::vector tensor_pos;
+  for (const auto idx : c10::irange(0, num_arguments)) {
+    const auto& ivalue = arguments[idx];
+    if (ivalue.isTensor()) {
+      auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+      tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim);
+      tensor_pos.push_back(static_cast(idx));
+    }
+  }
+  Func(tensor_inputs);
+
+  // Rebuild the argument list on the stack: non-tensor arguments are pushed
+  // back unchanged, tensor arguments are replaced by `tensor_inputs[i].first`.
+  size_t tensor_idx = 0;
+  TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
+  for (const auto arg_idx : c10::irange(0, num_arguments)) {
+    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
+      torch::jit::push(stack, arguments[arg_idx]);
+    } else {
+      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
+      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
+      tensor_idx++;
+    }
+  }
+
+  op.callBoxed(stack);
+  // Re-wrap every result as a batched tensor with its batch dim at index 0.
+  const auto returns = torch::jit::pop(*stack, num_returns);
+  for (const auto& ret : returns) {
+    if (ret.isTensor()) {
+      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
+    } else {
+      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
+    }
+  }
+}
+
+// Prepares the inputs of a pointwise op, in place.
+//
+// First computes the largest logical rank (rank without the batch dim) among
+// all inputs. Then, for every input, moves its batch dim to the front and calls
+// `maybePadToLogicalRank` with that largest rank. Only the tensor (`.first`) of
+// each pair is overwritten; the stored bdim (`.second`) is left as it was.
+inline void handle_pointwise_ops(std::vector>> &tensor_inputs) {
+  int64_t out_logical_rank = 0;
+  for (auto& tensor_input : tensor_inputs) {
+    int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
+    out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
+  }
+  for (auto& tensor_input: tensor_inputs) {
+    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
+    tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
+  }
+}
+
+#define POINTWISE_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction>());
+
+#define POINTWISE_BOXED2(op, overload) \
+  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction>());
+
+// Moves the batch dim of every input to the front, in place. Only the tensor
+// (`.first`) of each pair is overwritten; the stored bdim is not updated.
+inline void handle_variadic_bdims(std::vector>> &tensor_inputs) {
+  for (auto & tensor_input : tensor_inputs) {
+    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
+  }
+}
+
+#define VARIADIC_BDIMS_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction>());
+
+using UnpackedBatchedTensor = std::tuple>;
+
+// Scans the last `num_args` entries of `stack` and unwraps every Tensor
+// argument at `cur_level`.
+//
+// Outputs:
+//   tensors     - receives the unwrapped (value, bdim) result for each Tensor
+//                 argument, in argument order.
+//   tensors_pos - receives, in parallel, the index of each Tensor argument
+//                 relative to the first argument.
+//   batch_size  - receives the batch size shared by all batched tensors.
+//
+// Internally asserts that all batched tensors agree on the batch size and that
+// at least one tensor actually has a batch dim.
+inline void find_and_unpack_tensors(
+    const torch::jit::Stack* stack,
+    int64_t num_args,
+    int64_t cur_level,
+    SmallVector* tensors,
+    SmallVector* tensors_pos,
+    int64_t* batch_size) {
+
+  // -1 means "no batched tensor seen yet".
+  int64_t computed_batch_size = -1;
+  int64_t args_begin = static_cast(stack->size()) - num_args;
+
+  for (const auto idx : c10::irange(0, num_args)) {
+    const auto& ivalue = (*stack)[args_begin + idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    const auto& [tensor_value, tensor_bdim] = unpacked;
+    if (tensor_bdim.has_value()) {
+      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
+      if (computed_batch_size == -1) {
+        computed_batch_size = candidate_batch_size;
+      }
+      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
+    }
+
+    tensors->push_back(std::move(unpacked));
+    tensors_pos->push_back(idx);
+  }
+  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
+  *batch_size = computed_batch_size;
+}
+
+// Boxed batching rule that folds the batch dim of every Tensor argument into
+// dim 0, calls the op, and splits the batch dim back out of dim 0 of every
+// result.
+//
+// Tensor arguments without a batch dim are first given one via
+// `ensure_has_bdim`. All returns must be tensors (internal assert otherwise).
+// If no argument participates in the current level the op is called unchanged.
+inline void boxed_existing_bdim_all_batch_rule(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = static_cast(schema.arguments().size());
+
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  const auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
+
+  // Fast path: nothing on the stack is batched at this level.
+  const auto arguments = torch::jit::last(stack, num_arguments);
+  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    op.callBoxed(stack);
+    return;
+  }
+
+  // Stack index of the first argument.
+  int64_t args_begin = static_cast(stack->size()) - num_arguments;
+  SmallVector tensor_inputs;
+  SmallVector tensor_pos;
+  int64_t batch_size = 0;
+  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+  int64_t cur_level = maybe_layer->layerId();
+
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
+
+  // for each tensor, ensure it has a bdim and reshape it.
+  // The reshaped tensor overwrites the argument directly in its stack slot.
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& [value, bdim] = tensor_inputs[tensor_idx];
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(bdim.value_or(0), 0, value_);
+  }
+
+  op.callBoxed(stack);
+
+  // The returns are read from the stack slots starting at `args_begin` (where
+  // the arguments used to start) and re-wrapped with the batch dim at index 0.
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
+  }
+}
+
+// Use when all tensor arguments accept one (normal) batch dim.
+// This batching rule expands the batch dim on all Tensors, reshapes it into
+// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
+// This is not the most efficient thing; if there are alternatives, please try
+// to use them. Use this only as a last resort.
+#define EXISTING_BDIM_ALL_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction());
+
+// Boxed batching rule for ops whose Tensor arguments all have either rank
+// `feature_rank` (no user-level batch dim) or `feature_rank + 1`.
+//
+// The first Tensor argument decides which case applies:
+//   - logical rank == feature_rank: the vmap batch dim is simply moved to the
+//     front of every tensor, and the results are wrapped with bdim 0.
+//   - otherwise (asserted to be feature_rank + 1): the vmap batch dim is folded
+//     into dim 0 with `reshape_dim_into`, and split back out of dim 0 of every
+//     result with `reshape_dim_outof`.
+// The tensor whose index (among the Tensor arguments) equals
+// `contig_tensor_index` is additionally made contiguous.
+// NOTE(review): `feature_rank` and `contig_tensor_index` are presumably the
+// template parameters; the parameter list appears stripped in this copy.
+template 
+inline void boxed_all_tensors_have_optional_bdim(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
+  int64_t cur_level = maybe_layer->layerId();
+
+  // Fast path: nothing on the stack is batched at this level.
+  const auto arguments = torch::jit::last(stack, num_arguments);
+  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    op.callBoxed(stack);
+    return;
+  }
+
+  // Stack index of the first argument.
+  int64_t args_begin = static_cast(stack->size() - num_arguments);
+  SmallVector tensor_inputs;
+  SmallVector tensor_pos;
+  int64_t batch_size = 0;
+
+  find_and_unpack_tensors(
+      stack, static_cast(num_arguments), cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
+
+  // Decided once, from the first Tensor argument, and then enforced (via
+  // internal asserts) for all remaining ones.
+  std::optional is_no_batch_dim_case;
+
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    const auto logical_rank = rankWithoutBatchDim(value, bdim);
+
+    if (!is_no_batch_dim_case.has_value()) {
+      is_no_batch_dim_case = (logical_rank == feature_rank);
+    }
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    // ensure_has_bdim adds a missing batch dim at index 0.
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    if (*is_no_batch_dim_case) {
+      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
+      value_ = moveBatchDimToFront(value_, bdim);
+      if (tensor_idx == contig_tensor_index) {
+        value_ = value_.contiguous();
+      }
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
+    value_ = reshape_dim_into(*bdim, 0, value_);
+    if (tensor_idx == contig_tensor_index) {
+      value_ = value_.contiguous();
+    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
+  }
+
+  op.callBoxed(stack);
+
+  // The returns are read from the stack slots starting at `args_begin` and
+  // re-wrapped as batched tensors with the batch dim at index 0.
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    if (*is_no_batch_dim_case) {
+      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
+    } else {
+      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
+    }
+  }
+}
+
+// Useful for many NN operators.
+// The operator must satisfy the following:
+// - All arguments must accept an optional batch dim.
+// - All arguments must be the same rank
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction>());
+
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
+  m.impl(#op, \
+         torch::CppFunction::makeFromBoxedFunction<\
+             boxed_all_tensors_have_optional_bdim<\
+                 feature_rank, \
+                 contig_tensor_index>\
+             >());
+
+// Primary template: declared only; see the partial specialization below.
+// NOTE(review): template parameter lists appear stripped in this copy.
+template 
+struct ExistingBdimBatchRuleHelper;
+
+template 
+struct ExistingBdimBatchRuleHelper> {
+  // Folds the batch dim of `self` into dim 0, calls `Func`, then splits a
+  // dimension of the original batch size back out of dim 0 of the result and
+  // reports the output's batch dim as 0.
+  // NOTE(review): `self_bdim` is dereferenced unconditionally, so callers
+  // presumably guarantee that it holds a value -- confirm.
+  static std::tuple> apply(
+      const Tensor& self,
+      std::optional self_bdim,
+      T... extra_args) {
+    auto self_ = reshape_dim_into(*self_bdim, 0, self);
+    auto out = Func(self_, std::forward(extra_args)...);
+    return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
+  }
+};
+
+// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
+// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
+// It is important that this macro is not passed a function pointer!!
+#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
+    ExistingBdimBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits::parameter_types>::apply)
+
+
+#define EXISTING_BDIM(op) \
+  VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));
+
+#define EXISTING_BDIM2(op, overload) \
+  VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
+
+#define INVOKE(object,ptrToMember)  ((object).*(ptrToMember))
+
+
+// Batch rule for in-place unary methods: invokes the member function `Method`
+// on `self` (through the INVOKE macro) with the forwarded `extra_args` and
+// returns `self`. The batch-dim argument is not used.
+// NOTE(review): template parameter list appears stripped in this copy.
+template 
+Tensor& unary_inplace_batch_rule(Tensor& self, std::optional /*unused*/, ExtraArgs... extra_args) {
+  INVOKE(self, Method)(std::forward(extra_args)...);
+  return self;
+}
+
+// Returns the batch size, read from the first of the four (value, bdim) pairs
+// (in argument order) whose bdim is set. Triggers an internal assert if none of
+// the four has a batch dim.
+inline int64_t get_bdim_size4(
+    const Tensor& a_value, std::optional a_bdim,
+    const Tensor& b_value, std::optional b_bdim,
+    const Tensor& c_value, std::optional c_bdim,
+    const Tensor& d_value, std::optional d_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  if (c_bdim)
+    return c_value.size(*c_bdim);
+  if (d_bdim)
+    return d_value.size(*d_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// Returns the batch size, read from the first of the three (value, bdim) pairs
+// (in argument order) whose bdim is set. Triggers an internal assert if none of
+// the three has a batch dim.
+inline int64_t get_bdim_size3(
+    const Tensor& a_value, std::optional a_bdim,
+    const Tensor& b_value, std::optional b_bdim,
+    const Tensor& c_value, std::optional c_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  if (c_bdim)
+    return c_value.size(*c_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// Returns the batch size, read from `a_value` if `a_bdim` is set, otherwise
+// from `b_value` if `b_bdim` is set. Triggers an internal assert if neither has
+// a batch dim.
+inline int64_t get_bdim_size2(
+    const Tensor& a_value, std::optional a_bdim,
+    const Tensor& b_value, std::optional b_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// SymInt variant of get_bdim_size2: returns `sym_size` of the batch dim of
+// `a_value` if `a_bdim` is set, otherwise of `b_value` if `b_bdim` is set.
+// Triggers an internal assert if neither has a batch dim.
+inline c10::SymInt get_bdim_size2_symint(
+    const Tensor& a_value, std::optional a_bdim,
+    const Tensor& b_value, std::optional b_bdim) {
+  if (a_bdim)
+    return a_value.sym_size(*a_bdim);
+  if (b_bdim)
+    return b_value.sym_size(*b_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// [start, start + 1, ..., stop - 1]
+// Builds the half-open integer range [start, stop) as a VmapDimVector.
+// Internally asserts that `stop >= start`; an empty vector is returned when
+// they are equal.
+inline VmapDimVector range(int64_t start, int64_t stop) {
+  TORCH_INTERNAL_ASSERT(stop >= start);
+  VmapDimVector dims;
+  dims.reserve(stop - start);
+  for (int64_t i = start; i < stop; i++) {
+    dims.emplace_back(i);
+  }
+  return dims;
+}
+std::tuple _binary_pointwise_helper(
+    const Tensor& tensor, std::optional tensor_batch_dim, const Tensor& other, std::optional other_batch_dim,
+    bool do_type_promotion=true);
+
+} // namespace at::functorch
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h
new file mode 100644
index 0000000000000000000000000000000000000000..7351ea6fb52283786dae5dd24efe6dcdd4bea698
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h
@@ -0,0 +1,131 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include 
+#include 
+
+// This file contains template metaprogramming things that are used for our
+// batching rules.
+//
+// See NOTE: [vmap plumbing] for more details on why this is necessary.
+// The plumbing has a bunch of metaprogramming hacks for determining the signature
+// of a batching rule from the signature of the operator, many of which use the
+// helper functions in this file.
+
+namespace at::functorch {
+
+// Metaprogramming things
+template  using typelist = c10::guts::typelist::typelist;
+template  using head_t = c10::guts::typelist::head_t;
+template  using concat_t = c10::guts::typelist::concat_t;
+template  class debug_t;
+
+// tail operation
+template
+struct tail final {
+    static_assert(c10::guts::false_t::value,
+                  "In typelist::tail, the T argument must be typelist<...>.");
+};
+template
+struct tail> final {
+  using type = typelist;
+};
+template using tail_t = typename tail::type;
+
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
+  using type = Next;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, std::optional, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, std::optional, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, std::optional, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, std::optional, Next, Tail> {
+  using type = Tail;
+};
+template  struct RemoveBatchDimAfterTensor {
+  using first = head_t;
+  using next = tail_t;
+  using second = head_t;
+  using tail = tail_t;
+
+  using type = concat_t<
+    typelist,
+    typename RemoveBatchDimAfterTensor<
+      typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext::type
+    >::type
+  >;
+};
+template  struct RemoveBatchDimAfterTensor> {
+  using type = typelist;
+};
+template <> struct RemoveBatchDimAfterTensor> {
+  using type = typelist<>;
+};
+template using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor::type;
+
+template  struct UnpackSingleItemTuple {
+  using type = T;
+};
+template  struct UnpackSingleItemTuple> {
+  using type = T;
+};
+template  using unpack_single_item_tuple_t = typename UnpackSingleItemTuple::type;
+
+template  struct BuildFunctionHelper;
+template  struct BuildFunctionHelper> {
+  using type = Return(Args...);
+};
+template 
+struct BuildFunction {
+  using type = typename BuildFunctionHelper>::type;
+};
+template  using build_function_t = typename BuildFunction::type;
+
+
+template  struct ToOperatorType {
+  using batch_rule_return_type = typename c10::guts::function_traits::return_type;
+  using batch_rule_parameter_types = typename c10::guts::function_traits::parameter_types;
+
+  using operator_parameter_types = remove_batch_dim_after_tensor_t;
+  using operator_return_type =
+    unpack_single_item_tuple_t<
+      c10::guts::typelist::to_tuple_t<
+        remove_batch_dim_after_tensor_t<
+          c10::guts::typelist::from_tuple_t>>>;
+
+  using type = build_function_t;
+};
+template  using to_operator_t = typename ToOperatorType::type;
+
+} // namespace at::functorch
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa7e4e2645bfa7ea7c63e3cc9ec99721769697f6
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h
@@ -0,0 +1,192 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include 
+#include 
+
+namespace at::functorch {
+
+// This file contains the legacy (now-deprecated) batching rule API.
+// Please try to use the new-style batching rule API (see writing_batch_rules.md)
+
+// This file contains abstractions used for transforming *logical* vmap arguments
+// into *physical* arguments. (Keep reading for definitions of these terms).
+
+// NOTE: [Logical vs physical args]
+// Consider the following vmap.
+//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
+// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
+// with batch dims 0 and 2:
+//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
+//
+// We say the *logical* view of the tensor has size [3] -- tensors inside
+// `func` appear to have size [3].
+// However, the *physical* underlying tensor (the one passed to vmap) has size
+// [2, 3, 4].
+//
+// This notion of logical vs physical also extends to non-tensor arguments.
+// Consider the previous tensor; let's assume the user called
+// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
+// dimension they are reducing over is dim 0 but the physical dim is dim 1
+// (the first non-batch dimension)
+
+// Forward declared; see NOTE: [What is a VmapPhysicalView?]
+struct VmapPhysicalView;
+
+// Most PyTorch operators take 4 or fewer inputs.
+constexpr int64_t kVmapTransformStaticInputSize = 4;
+using VmapPhysicalViewVec = SmallVector;
+
+// PyTorch generally advertises good performance for <= 5 dims.
+// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
+// dimensions to get 8. Adjust this number as necessary
+constexpr int64_t kVmapStaticDimVecSize = 8;
+using VmapDimVector = SmallVector;
+using VmapSymDimVector = SmallVector;
+
+// NOTE: [What is an VmapTransform?]
+// A *VmapTransform* converts logical views of tensors to physical views.
+//
+// Batching rules use VmapTransforms to convert logical arguments to
+// physical arguments, then call one or more at:: operators that handle the
+// physical arguments, and then convert the physical result back to a logical
+// argument.
+
+// VmapTransform for operators that take tensors with multiple batch dims.
+// Given one or more logical views on Tensors, `logicalToPhysical`
+// permutes all of the batch dims to the front of the tensor, aligns
+// and expands the batch dims to match each other (according to their `level`),
+// and returns a VmapPhysicalView on the tensor(s).
+struct TORCH_API MultiBatchVmapTransform {
+  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
+  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
+};
+
+// VmapTransform for operators that broadcast all inputs.
+// Given some logical views on Tensors, `logicalToPhysical`:
+// - permutes all of the batch dims to the front of the tensors
+// - aligns all the batch dims to the collective levels of all of the tensors.
+//   If a tensor does not have a batch dim for a vmap level, then it receives
+//   a size-one dimension for said level.
+// - aligns the non-batch dims to have the same dimensionality, adding extra
+//   size-1 dimensions in between the batch dimensions and the non-batch dimensions
+//   so that the batch dimensions are lined up from the right.
+//
+// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
+// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
+// of size (B, 1, 2) and (B, 3, 2).
+//
+// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
+// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
+// actually *need* to return a tensor of size (1, 2) for the second tensor
+// because the broadcasting operation takes care of that for us, but we do
+// it anyways to keep things simple.
+struct TORCH_API BroadcastingVmapTransform {
+  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
+};
+
+// Forward declared, if you're reading this file head to toe, don't worry about
+// it yet.
+struct VmapPhysicalToLogicalMap;
+
+// NOTE: [What is a VmapPhysicalView?]
+// VmapPhysicalView represents a physical view on a Tensor.
+//
+// One can use it to further convert logical dimension indices, logical shapes,
+// and more to their physical variants, or convert a new (physical) tensor into
+// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
+//
+// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
+// the front and some levels that correspond to said batch dimensions.
+//
+// The levels bitset specifies which vmap levels correspond to the batch
+// dimensions at the front of the tensor. In particular, the number of set bits
+// corresponds to the number of batch dimensions on `tensor` and the rightmost
+// bit of `levels` specifies the maximum number of nested vmaps we are in at
+// this point in time.
+// For example, given:
+//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
+//
+// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
+// than or equal to 3.
+//   bitset: 010100
+//              ^
+//              |
+//   levels: 012345
+struct TORCH_API VmapPhysicalView {
+  VmapPhysicalView(Tensor&& tensor, std::bitset levels)
+      : levels_(levels), tensor_(std::move(tensor)) {
+    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
+  }
+
+  Tensor& tensor() { return tensor_; }
+  const Tensor& tensor() const { return tensor_; }
+
+  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
+  //
+  // For example, given:
+  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
+  //
+  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
+  // This is because the size of levels tells us that the first two dimensions
+  // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
+  // a physical dim of `n + 2`.
+  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
+  int64_t getPhysicalDim(int64_t logical_dim) const;
+
+  // Returns a VmapPhysicalToLogicalMap object. This can be used for
+  // mapping a physical tensor to a new logical tensor (BatchedTensor)
+  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
+
+  // Maps a logical shape to a physical shape by prepending the batch
+  // sizes to the logical shape.
+  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
+  SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;
+
+  int64_t numBatchDims() const;
+
+ private:
+  int64_t numLogicalDims() const;
+
+  std::bitset levels_;
+  Tensor tensor_;
+};
+
+// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
+// to a logical one (BatchedTensor). It holds some levels that are used to do the
+// mapping and assumes that the batch dimensions in the physical tensor all
+// occur at the front of the tensor.
+struct TORCH_API VmapPhysicalToLogicalMap {
+  VmapPhysicalToLogicalMap(std::bitset levels): levels_(levels) {}
+
+  // Maps a physical tensor to a new logical tensor (BatchedTensor).
+  // Assumes that all of the "batch dimensions" are at the front
+  // of the physical tensor. For example, given:
+  // - x = rank-4 Tensor with size 2, 3, 5, 7
+  // - levels = (2, 4)
+  // Returns:
+  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
+  Tensor apply(const Tensor& physical_tensor) const;
+
+  // Given a vector of physical tensors,
+  // 1. maps each tensor to a new logical tensor. Assumes that all of the
+  //    "batch dimensions" are at the front of the physical tensors.
+  // 2. stores the new logical tensors back into the passed-in vector. This is
+  //    to avoid additional dynamic allocations.
+  void applyInplace(std::vector& physical_tensors) const;
+
+  std::bitset levels_;
+};
+
+
+} // namespace at::functorch
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Macros.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Macros.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd7386a3a3cd55c40a744d2c8700c0a6021a008c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Macros.h
@@ -0,0 +1,8 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#define SINGLE_ARG(...) __VA_ARGS__
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/VmapInterpreter.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/VmapInterpreter.h
new file mode 100644
index 0000000000000000000000000000000000000000..f1fc965730899b849d263d98d8faa0e891f777ae
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/VmapInterpreter.h
@@ -0,0 +1,30 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+
+namespace at::functorch {
+
+// This is the interpreter that handles the vmap() transform.
+// See NOTE: [functorch interpreter stack] for more details.
+
+struct VmapInterpreterPtr {  // Thin wrapper over an Interpreter* whose key() is TransformType::Vmap.
+  explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); }
+  TransformType key() const { return base_->key(); }  // forwarded to the wrapped interpreter
+  int64_t level() const { return base_->level(); }  // forwarded to the wrapped interpreter
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
+  c10::SymInt batchSize() const {  // reads batchSize_ out of the wrapped interpreter's meta()
+    return std::get(base_->meta()).batchSize_;
+  }
+  RandomnessType randomness() const {  // reads randomness_ out of the wrapped interpreter's meta()
+    return std::get(base_->meta()).randomness_;
+  }
+ private:
+  const Interpreter* base_;  // raw pointer to the wrapped interpreter; checked in the constructor
+};
+
+} // namespace at::functorch
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/EmptyTensor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/EmptyTensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..3507c0e17afd44189a5a69e3bc216b10a61dd626
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/EmptyTensor.h
@@ -0,0 +1,33 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+#include 
+
+namespace at::detail {
+
+C10_EXPORT TensorBase empty_mps(
+    IntArrayRef size,
+    std::optional dtype_opt,
+    std::optional layout_opt,
+    std::optional device_opt,
+    std::optional pin_memory_opt,
+    std::optional memory_format_opt);
+C10_EXPORT TensorBase empty_mps(IntArrayRef size, const TensorOptions& options);
+
+C10_EXPORT TensorBase empty_strided_mps(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    std::optional device_opt);
+
+C10_EXPORT TensorBase empty_strided_mps(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions& options);
+
+} // namespace at::detail
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/IndexKernels.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/IndexKernels.h
new file mode 100644
index 0000000000000000000000000000000000000000..be3ad0b5c05a9d98cf1ef253e0fe69b7816d4af1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/IndexKernels.h
@@ -0,0 +1,225 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+namespace at::mps {
+
+static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
+template
+Y cast(const X x);
+
+template<>
+{1} cast<{1}, {0}>(const {0} x) {{
+ return {2};
+}}
+
+kernel void scatter_kernel_n(uint linear_index          [[thread_position_in_grid]],
+                             constant void * src_       [[buffer(0)]],
+                             device void * dst_         [[buffer(1)]],
+                             constant uint32_t * size   [[buffer(2)]],
+                             constant uint32_t * stride [[buffer(3)]],
+                            constant uint32_t & numel   [[buffer(4)]],
+                            constant int32_t & ndim     [[buffer(5)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    uint64_t dst_offs = 0;
+    auto dst_idx = linear_index;
+    for(int dim = ndim - 1; dim >= 0; --dim) {{
+      dst_offs += stride[dim] * (dst_idx % size[dim]);
+      dst_idx /= size[dim];
+    }}
+
+    dst[dst_offs] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint4 & size   [[buffer(2)]],
+                             constant packed_uint4 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint4 local_index;
+    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
+    local_index.y = linear_index / (size[3] * size[2]) % size[1];
+    local_index.z = linear_index / size[3] % size[2];
+    local_index.w = linear_index % size[3];
+
+    const packed_uint4 strided_index = local_index * stride;
+    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint3 & size   [[buffer(2)]],
+                             constant packed_uint3 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint3 local_index;
+    local_index.x = linear_index / (size[2] * size[1]) % size[0];
+    local_index.y = linear_index / size[2] % size[1];
+    local_index.z = linear_index % size[2];
+
+    const packed_uint3 strided_index = local_index * stride;
+    dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint2 & size   [[buffer(2)]],
+                             constant packed_uint2 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint2 local_index;
+    local_index.x = linear_index / size[1] % size[0];
+    local_index.y = linear_index % size[1];
+
+    const packed_uint2 strided_index = local_index * stride;
+    dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_1(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant int & size            [[buffer(2)]],
+                             constant int & stride          [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    const int local_index = linear_index % size;
+    const int strided_index = local_index * stride;
+    dst[strided_index] = cast<{1}>(src[linear_index]);
+}}
+)METAL_SCATTER";
+
+static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
+template
+Y cast(const X x);
+
+template<>
+{1} cast<{1}, {0}>(const {0} x) {{
+ return {2};
+}}
+
+kernel void gather_kernel_n(uint linear_index           [[thread_position_in_grid]],
+                            constant void * src_        [[buffer(0)]],
+                            device void * dst_          [[buffer(1)]],
+                            constant uint32_t * size    [[buffer(2)]],
+                            constant uint32_t * stride  [[buffer(3)]],
+                            constant uint32_t & numel   [[buffer(4)]],
+                            constant int32_t & ndim     [[buffer(5)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    uint64_t src_offs = 0;
+    auto src_idx = linear_index;
+    for(int dim = ndim - 1; dim >= 0; --dim) {{
+      src_offs += stride[dim] * (src_idx % size[dim]);
+      src_idx /= size[dim];
+    }}
+
+    dst[linear_index] = cast<{1}>(src[src_offs]);
+}}
+
+kernel void gather_kernel_4(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint4 & size    [[buffer(2)]],
+                            constant packed_uint4 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint4 local_index;
+    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
+    local_index.y = linear_index / (size[3] * size[2]) % size[1];
+    local_index.z = linear_index / size[3] % size[2];
+    local_index.w = linear_index % size[3];
+
+    const packed_uint4 strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
+}}
+
+kernel void gather_kernel_3(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint3 & size    [[buffer(2)]],
+                            constant packed_uint3 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint3 local_index;
+    local_index.x = linear_index / (size[2] * size[1]) % size[0];
+    local_index.y = linear_index / size[2] % size[1];
+    local_index.z = linear_index % size[2];
+
+    const packed_uint3 strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
+}}
+
+kernel void gather_kernel_2(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint2 & size    [[buffer(2)]],
+                            constant packed_uint2 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint2 local_index;
+    local_index.x = linear_index / size[1] % size[0];
+    local_index.y = linear_index % size[1];
+
+    const packed_uint2 strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
+}}
+
+kernel void gather_kernel_1(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant int & size             [[buffer(2)]],
+                            constant int & stride           [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    const int local_index = linear_index % size;
+    const int strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index]);
+}}
+)METAL_GATHER";
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSAllocator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..f295d43e1e5acec8c55d03ba4e3a62478567e63d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSAllocator.h
@@ -0,0 +1,442 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// this implementation is based on CUDACachingAllocator.
+// It utilizes Metal Heaps to improve the performance with buffer allocation.
+// Do not include this header. Use MPSAllocatorInterface.h instead.
+// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
+namespace at::mps::HeapAllocator {
+
+static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB (sizes up to this are placed in a kSmallHeap)
+static const size_t kMinLargeAlloc = MB(10); // allocations above 1 MiB and below 10 MiB are placed in a kLargeHeap
+static const size_t kRoundLarge = MB(2); // heaps created for a single (non-split) buffer are rounded up to a multiple of 2 MiB
+static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
+static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
+static const size_t kXLargeHeapD =
+    MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
+static const size_t kXLargeHeapU =
+    MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
+static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation (the size of an int64_t)
+
+// Bit flags describing a buffer pool; a pool's usage value is a bitwise-OR combination of these flags.
+enum UsageFlags : uint32_t {
+  PRIVATE = 0, // no storage bit set: HeapBlock::getOptions() then selects the private storage mode
+  SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
+  SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
+  MANAGED = (1 << 2), // managed storage mode (takes precedence over SHARED in HeapBlock::getOptions())
+  HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
+  SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
+};
+// Debug verbosity bit flags (stored in m_debug_verbosity, which is set via the "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var).
+enum DebugVerbosity : uint32_t {
+  SILENT = 0, // no verbosity flag set
+  PROFILING = (1 << 0), // print generic profiling data for total system memory usage
+  ALLOCATIONS = (1 << 1), // print buffer allocations
+  RECYCLES = (1 << 2), // print buffer recycling
+  RELEASES = (1 << 3), // print buffer releases
+  LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
+};
+
+struct HeapBlock;
+
+struct BufferBlock { // bookkeeping record for a single Metal buffer handled by the allocator
+  id buffer;
+  void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
+  size_t size; // size after alignment
+  size_t requested_size; // requested size (before alignment)
+  // buffer shape is used for retrieving base of views in cached graphs
+  std::vector shape;
+  bool in_use = false;
+  HeapBlock* heap; // heap passed at construction (nullptr by default)
+  id_t buf_id; // unique non-zero id when constructed with a buffer, otherwise 0
+  // counter used to pick least-recently-used buffers as candidates for garbage collection
+  uint32_t gc_count = 0;
+  uint32_t use_count = 0;
+  // counter to assign unique ids to buffer blocks
+  static uint64_t buffer_counter;
+  // Metal events used to sync GPU/CPU operations on the shared-storage buffers
+  MPSEventPtr event;
+
+  BufferBlock(size_t Size, size_t RequestedSize = 0, const id Buffer = nullptr, HeapBlock* Heap = nullptr)
+      : buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {}
+
+  static bool Comparator(const BufferBlock* a, const BufferBlock* b) { // orders by size, then by buffer address as a tie-breaker
+    return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
+  }
+  static size_t alignUp(size_t Size, size_t Alignment) { // rounds Size up to the next multiple of Alignment
+    assert(Alignment != 0 && ((Alignment - 1) & Alignment) == 0); // must be a non-zero power of two (0 used to pass and yield 0)
+    return ((Size + Alignment - 1) & ~(Alignment - 1));
+  }
+  uint32_t retainCount() const { // Objective-C retain count of the wrapped buffer
+    return [buffer retainCount];
+  }
+};
+typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
+
+struct BufferPool;
+struct AllocParams { // parameters (and state flags) describing one allocation request against a pool
+  AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool)
+      : search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {}
+  size_t size() const { // the allocation size stored in the search key
+    return search_key.size;
+  }
+
+  BufferBlock search_key; // buffer-less block carrying only the allocation size
+  BufferPool* pool;
+  BufferBlock* buffer_block = nullptr; // presumably set to the block that satisfies the request - TODO confirm in the .mm
+  size_t requested_size;
+  // true if we exceed the low watermark limit. In this case
+  // we apply strategies to relieve the pressure before allocation.
+  bool has_memory_pressure = false;
+  // true if we're allocating on a unified memory device
+  bool has_unified_memory = true;
+};
+
+struct HeapBlock { // wraps one Metal heap together with its size bookkeeping
+  id heap;
+  struct {
+    size_t total, available; // both start at the constructor's Size; available is refreshed by updateAvailableSize()
+  } size;
+  BufferPool* pool;
+  unsigned int n_buffers = 0; // number of buffers currently allocated from this heap
+  id_t heap_id; // unique non-zero id when constructed with a heap, otherwise 0
+  // indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer)
+  bool is_split;
+  // counter to assign unique ids to heap blocks
+  static uint64_t heap_counter;
+
+  HeapBlock(size_t Size, const id Heap = nullptr, BufferPool* Pool = nullptr)
+      : heap(Heap),
+        size({.total = Size, .available = Size}),
+        pool(Pool),
+        heap_id(Heap ? ++heap_counter : 0),
+        is_split(true) {}
+
+  static MTLResourceOptions getOptions(uint32_t usage) { // maps UsageFlags bits to Metal resource options
+    // TODO: check the caching performance of write-combined mode
+    MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
+
+    if (usage & UsageFlags::MANAGED) // storage mode precedence: MANAGED, then SHARED, else private
+      options |= MTLResourceStorageModeManaged;
+    else if (usage & UsageFlags::SHARED)
+      options |= MTLResourceStorageModeShared;
+    else
+      options |= MTLResourceStorageModePrivate;
+
+    options |=
+        (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
+
+    return options;
+  }
+
+  static HeapBlock* createHeapBlock(AllocParams& params, id device, uint32_t usage) { // returns nullptr if no heap could be created
+    HeapBlock* heapBlock = nullptr;
+    bool is_split = true;
+    const size_t size = params.size();
+    MTLHeapDescriptor* d = [MTLHeapDescriptor new];
+    if (d) {
+      const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
+      if (size <= kMaxSmallAlloc) { // the heap size is chosen from the requested size
+        d.size = kSmallHeap;
+      } else if (size < kMinLargeAlloc) {
+        d.size = kLargeHeap;
+      } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
+        d.size = kXLargeHeap;
+      } else {
+        d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); // dedicated heap, rounded up to a multiple of kRoundLarge
+        is_split = false;
+      }
+      d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
+      d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+      // this automatically handles Metal buffer access synchronizations at the
+      // cost of slightly lower performance.
+      d.hazardTrackingMode =
+          (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
+      d.resourceOptions = getOptions(usage);
+      d.type = MTLHeapTypeAutomatic;
+      id heap = [device newHeapWithDescriptor:d];
+      if (heap) {
+        [heap setPurgeableState:MTLPurgeableStateNonVolatile];
+        const size_t heap_size = heapAvailableSize(heap); // record what the heap reports as available, not d.size
+        heapBlock = new HeapBlock(heap_size, heap, params.pool);
+        if (heapBlock) {
+          heapBlock->is_split = is_split;
+        }
+      }
+      [d release]; // the descriptor is released on both the success and the failure path
+    }
+    return heapBlock;
+  }
+  static bool Comparator(const HeapBlock* a, const HeapBlock* b) { // orders by available size, then by heap address as a tie-breaker
+    return (a->size.available != b->size.available) ? a->size.available < b->size.available
+                                                    : (uintptr_t)a->heap < (uintptr_t)b->heap;
+  }
+  static NSUInteger heapAvailableSize(id heap, size_t Alignment = vm_page_size) { // largest available size for the given alignment
+    return [heap maxAvailableSizeWithAlignment:Alignment];
+  }
+  NSUInteger Size() { // size reported by the Metal heap itself
+    return [heap size];
+  }
+  id newMTLBuffer(size_t length, uint32_t usage) { // sub-allocates a buffer from this heap; bookkeeping is updated only on success
+    id buf = [heap newBufferWithLength:length options:getOptions(usage)];
+    if (buf) {
+      updateAvailableSize();
+      n_buffers++;
+    }
+    return buf;
+  }
+  // returns the retainCount before releasing the buffer
+  uint32_t releaseMTLBuffer(id& buffer) {
+    const uint32_t retainCount = [buffer retainCount];
+    [buffer release];
+    buffer = nil; // the caller's reference is cleared as well (passed by reference)
+    updateAvailableSize();
+    n_buffers--;
+    return retainCount;
+  }
+  // returns the retainCount before releasing the heap
+  uint32_t releaseMTLHeap() {
+    const uint32_t retainCount = [heap retainCount];
+    TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
+    [heap setPurgeableState:MTLPurgeableStateEmpty];
+    [heap release];
+    heap = nil;
+    size.available = 0;
+    return retainCount;
+  }
+  uint32_t retainCount() const { // Objective-C retain count of the wrapped heap
+    return [heap retainCount];
+  }
+  void updateAvailableSize() { // re-queries the heap for its available size
+    size.available = heapAvailableSize(heap);
+  }
+};
+typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
+
+struct BufferPool { // per-device pool of heaps and buffers sharing the same usage flags
+  enum class Kind {
+    PRIVATE_SMALL,
+    PRIVATE_LARGE,
+    SHARED_SMALL,
+    SHARED_LARGE,
+    SCALAR,
+  };
+
+  BufferPool(const id Device, uint32_t Usage)
+      : device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {}
+
+  const id device;
+  // usage flags to customize the pool for various purposes (see UsageFlags enum)
+  const uint32_t usage;
+  // total number of buffers in the pool
+  uint32_t n_buffers = 0;
+  // total size of the allocations made on this pool
+  size_t allocated_size = 0;
+  // total memory available in the pool
+  size_t available_size = 0;
+  // set of heaps ordered by their "available" (not total) memory size (see HeapBlock::Comparator)
+  std::set heaps;
+  // set of only the "available" buffers in the pool (i.e., buffers not in-use), ordered by BufferBlock::Comparator
+  std::set available_buffers;
+  // set of buffers that are in a state of "limbo": they have already been freed
+  // on the PyTorch side, but were not returned to the pool because they are still
+  // in use by command buffers (retainCount > 1). In this state, a buffer is
+  // neither ready to be recycled, nor can it be returned to the pool as available.
+  // These buffers will be returned to the pool once the command buffer's
+  // completionHandler callbacks are called.
+  std::unordered_set buffers_pending_free;
+  // set of heaps pending a size update
+  std::unordered_set heaps_pending_update;
+};
+
+class MPSHeapAllocatorImpl {
+ public:
+  explicit MPSHeapAllocatorImpl()
+      : m_device(at::mps::MPSDevice::getInstance()->device()),
+        m_max_buffer_size([m_device maxBufferLength]),
+        m_stream(getDefaultMPSStream()),
+        m_event_pool(getMPSEventPool()) {
+    init_allocator();
+  }
+  ~MPSHeapAllocatorImpl() { // the cache is emptied when the allocator is destroyed
+    emptyCache();
+  }
+  // interface exposed to at::Allocator
+  id malloc(size_t size, uint32_t usage);
+  // frees a buffer and returns it to the buffer pool
+  void free(void* ptr);
+  // releases all the cached buffers and their associated heaps
+  void emptyCache();
+  // frees inactive buffers that are pending to be freed
+  void freeInactiveBuffers();
+  // returns true if buffer was allocated from the shared pool
+  bool isSharedBuffer(const void* ptr);
+  // get the requested unaligned size of an MTLBuffer
+  ssize_t getUnalignedBufferSize(const void* ptr);
+  // set the shape of a base tensor from a view tensor
+  void setBufferShape(const void* ptr, const IntArrayRef& shape);
+  // retrieve the shape of a base tensor from a view tensor
+  IntArrayRef getBufferShape(const void* ptr);
+  // get the unique ID of the buffer
+  id_t getBufferId(const void* ptr);
+  // allocate a buffer from a specialized pool to import CPU scalars into GPU
+  id allocScalarBufferWithValue(void* value, size_t size);
+  // returns a CPU-mapping of the input buffer and its retainCount,
+  // only if it has Shared storage-mode and was allocated by MPSAllocator
+  std::pair getSharedBufferPtr(const void* buffer);
+  // records events for a list of MTLBuffers (a list is used so the mutex is locked only once)
+  // returns true if any event was recorded (i.e., if the passed buffers exist and are shared-storage)
+  bool recordEvents(c10::ArrayRef buffers);
+  // waits for the event to signal the completion of GPU execution
+  // on the passed shared buffers (a list is used so the mutex is locked only once)
+  // returns true if actually waited on any event
+  bool waitForEvents(c10::ArrayRef buffers);
+  // this indicates how far (in Megabytes) the current total allocations are from the
+  // low watermark limit which is used to detect if we're under memory pressure
+  // This returns zero if we've reached the low watermark limit
+  ssize_t getLowWatermarkValue();
+  // (see m_low_watermark_ratio for description)
+  void setLowWatermarkRatio(double ratio);
+  // (see m_high_watermark_ratio for description)
+  void setHighWatermarkRatio(double ratio);
+  // (see m_low_watermark_limit for description)
+  size_t getLowWatermarkLimit() const {
+    return m_low_watermark_limit;
+  }
+  // (see m_max_total_allowed_size for description)
+  size_t getHighWatermarkLimit() const {
+    return m_max_total_allowed_size;
+  }
+  // (see m_total_allocated_memory for description)
+  size_t getTotalAllocatedMemory() const {
+    return m_total_allocated_memory;
+  }
+  // (see m_current_allocated_memory for description)
+  size_t getCurrentAllocatedMemory() const {
+    return m_current_allocated_memory;
+  }
+  // total GPU memory allocated in the process by Metal driver; including
+  // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
+  size_t getDriverAllocatedMemory() const {
+    return current_allocated_size();
+  }
+  // recommended Max memory for Metal
+  size_t getRecommendedMaxMemory() const {
+    return max_device_size();
+  }
+  // (see enum DebugVerbosity for description)
+  uint32_t getDebugVerbosity() const {
+    return m_debug_verbosity;
+  }
+  // returns the device that we allocate from
+  inline id Device() const {
+    return m_device;
+  }
+
+  inline std::string format_size(uint64_t size) const;
+
+ private:
+  // (see m_high_watermark_ratio for description)
+  constexpr static double default_high_watermark_ratio = 1.7;
+  // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
+  constexpr static double default_high_watermark_upper_bound = 2.0;
+  // (see m_low_watermark_ratio for description)
+  // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
+  constexpr static double default_low_watermark_ratio_unified = 1.4;
+  constexpr static double default_low_watermark_ratio_discrete = 1.0;
+
+  const id m_device;
+  std::recursive_mutex m_mutex;
+  // allocated buffers by device pointer
+  ska::flat_hash_map m_allocated_buffers;
+  // using a container for pools to simplify iterating them
+  ska::flat_hash_map> m_pools;
+  // total memory allocated by HeapAllocator (including blocks in pools)
+  size_t m_total_allocated_memory = 0;
+  // currently active memory allocations in use (i.e., blocks not in pools)
+  size_t m_current_allocated_memory = 0;
+  // max buffer size allowed by Metal (initialized from the device's maxBufferLength)
+  size_t m_max_buffer_size = 0;
+  // maximum total size allowed to be allocated
+  size_t m_max_total_allowed_size = 0;
+  // high watermark ratio is a hard limit for the total allowed allocations
+  // 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
+  // 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
+  // >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
+  // e.g., value 0.95 means we allocate up to 95% of recommended maximum
+  // allocation size; beyond that, the allocations would fail with OOM error.
+  double m_high_watermark_ratio;
+  // low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
+  // level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
+  // Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
+  // e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
+  // allocation size.
+  double m_low_watermark_ratio;
+  // low watermark size limit (in Bytes) at the time we initialize the allocator
+  size_t m_low_watermark_limit;
+  // use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
+  uint32_t m_debug_verbosity;
+  // default MPS stream
+  MPSStream* m_stream;
+  // we hold a reference to MPSEventPool so that it gets destroyed only after MPSAllocator
+  std::shared_ptr m_event_pool;
+
+  void init_allocator();
+  void init_buffer_pools();
+  HeapBlock* get_free_heap(AllocParams& params);
+  bool get_free_buffer(AllocParams& params);
+  BufferBlock* get_allocated_buffer_block(const void* ptr);
+  BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
+  bool alloc_buffer(AllocParams& params);
+  void free_buffer(BufferBlock* buffer_block);
+  // returns true if the container heap is also released
+  bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
+  void release_buffers(BufferPool& pool);
+  bool release_available_cached_buffers(AllocParams& params);
+  bool release_cached_buffers();
+  // free unused cached blocks to reclaim GPU memory if memory pressure is high
+  void garbage_collect_cached_buffers(AllocParams& params);
+  // returns the suitable buffer pool type for the usage or
+  // requested/allocated sizes
+  BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
+  // returns the aligned allocation size that is optimized
+  // for the buffers to get reused frequently
+  size_t get_allocation_size(size_t size, uint32_t usage) const;
+  // maximum size of device memory available for allocation in current process
+  // Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
+  size_t max_device_size() const {
+    return [m_device recommendedMaxWorkingSetSize];
+  }
+  // there are implicit allocations from MPS backend, so we need to query the 'device' for
+  // total allocated size instead of manually tracking in MPSAllocator
+  size_t current_allocated_size() const {
+    return [m_device currentAllocatedSize];
+  }
+
+  bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const { // always returns true
+    for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) { // runs every registered callback
+      MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(
+          buffer_block ? buffer_block->buffer : nullptr, event); // callbacks receive nullptr when there is no block
+    }
+    return true;
+  }
+};
+
+} // namespace at::mps::HeapAllocator
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf8de460db3c678225e92d6733a6975a4df7a11c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h
@@ -0,0 +1,73 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#define MB(x) ((x) * 1048576UL) // mebibytes -> bytes; the argument is parenthesized so that MB(a + b) expands correctly
+
+namespace at::mps {
+
+// This is the public interface used to access MPSAllocator.
+// Do not declare methods that would depend on MPS or Metal frameworks.
+class IMPSAllocator : public c10::Allocator {
+ public:
+  // see the comments in MPSAllocator.h for the description of these methods.
+  virtual void emptyCache() const = 0; // releases all the cached buffers and their associated heaps
+  virtual void freeInactiveBuffers() const = 0; // frees inactive buffers that are pending to be freed
+  virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0; // requested (unaligned) size of the buffer
+  virtual IntArrayRef getBufferShape(const void* ptr) const = 0; // shape of a base tensor from a view tensor
+  virtual id_t getBufferId(const void* ptr) const = 0; // unique ID of the buffer
+  virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
+      const = 0; // sets the shape of a base tensor from a view tensor
+  virtual bool isSharedBuffer(const void* ptr) const = 0; // true if the buffer was allocated from the shared pool
+  virtual bool isSharedStorageSupported() const = 0;
+  virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
+      const = 0; // allocates a buffer used to import a CPU scalar into GPU
+  virtual std::string formatSize(size_t size) const = 0;
+  virtual void setLowWatermarkRatio(double ratio) const = 0;
+  virtual void setHighWatermarkRatio(double ratio) const = 0;
+  virtual ssize_t getLowWatermarkValue() const = 0;
+  virtual size_t getLowWatermarkLimit() const = 0;
+  virtual size_t getHighWatermarkLimit() const = 0;
+  virtual size_t getTotalAllocatedMemory() const = 0;
+  virtual size_t getCurrentAllocatedMemory() const = 0;
+  virtual size_t getDriverAllocatedMemory() const = 0;
+  virtual size_t getRecommendedMaxMemory() const = 0;
+  virtual std::pair getSharedBufferPtr(
+      const void* ptr) const = 0; // CPU-mapping of a shared-storage buffer and its retainCount
+  virtual bool recordEvents(c10::ArrayRef buffers) const = 0; // true if any event was recorded
+  virtual bool waitForEvents(c10::ArrayRef buffers) const = 0; // true if actually waited on any event
+};
+
+class IMpsAllocatorCallback { // interface for callbacks registered in MPSAllocatorCallbacksRegistry
+ public:
+  enum class EventType { // kind of allocator event reported to the callback
+    ALLOCATED, // buffer got allocated to be used immediately
+    RECYCLED, // buffer pulled from free list to be reused
+    FREED, // buffer put to free list for future recycling
+    RELEASED, // buffer memory released
+    ALLOCATION_FAILED // buffer allocation failed
+  };
+  virtual ~IMpsAllocatorCallback() = default;
+  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0; // ptr is the affected buffer and may be nullptr
+};
+
+// MPS allocator executes every registered callback for each allocator event it reports
+// (see IMpsAllocatorCallback::EventType), e.g. when a block of memory is freed.
+TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
+#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
+  C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__)
+
+IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
+
+bool isMPSPinnedPtr(const void* data);
+
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSDevice.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSDevice.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6edee0a5332ab72e04e334af2c341009fd16817
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSDevice.h
@@ -0,0 +1,90 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __OBJC__
+#include 
+#include 
+typedef id MTLDevice_t;
+#else
+typedef void* MTLDevice_t;
+#endif
+
+namespace at::mps {
+
+// Helper enum to check if an MPSGraph op is supported in a given macOS version
+enum class MacOSVersion : uint32_t {
+  MACOS_VER_14_4_PLUS = 0, // enumerators are listed from the oldest to the newest macOS version
+  MACOS_VER_15_0_PLUS,
+  MACOS_VER_15_1_PLUS,
+  MACOS_VER_15_2_PLUS,
+};
+
+//-----------------------------------------------------------------
+//  MPSDevice
+//
+// MPSDevice is a singleton class that returns the default device
+//-----------------------------------------------------------------
+
+class TORCH_API MPSDevice {
+ public:
+  /**
+   * MPSDevice should not be cloneable.
+   */
+  MPSDevice(MPSDevice& other) = delete;
+  /**
+   * MPSDevice should not be assignable.
+   */
+  void operator=(const MPSDevice&) = delete;
+  /**
+   * Gets the single instance of the Device.
+   */
+  static MPSDevice* getInstance();
+  /**
+   * Returns the underlying Metal device handle (MTLDevice_t).
+   */
+  MTLDevice_t device() {
+    return _mtl_device;
+  }
+  /**
+   * Presumably returns whether the running macOS is at least `version` (the name predates the enum) - TODO confirm
+   */
+  bool isMacOS13Plus(MacOSVersion version) const;
+
+  /**
+   * Returns the device name
+   */
+  std::string getName() const;
+
+  /**
+   * Returns number of GPU cores.
+   * 1 Core = 16 ExecutionUnit x 8 ALU x 24 threads
+   */
+  unsigned getCoreCount() const;
+
+  ~MPSDevice();
+
+ private:
+  static MPSDevice* _device; // presumably the singleton returned by getInstance() - TODO confirm
+  MTLDevice_t _mtl_device; // handle returned by device()
+  MPSDevice(); // private constructor: not constructible from outside the class
+};
+
+TORCH_API bool is_available();
+TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
+TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
+
+inline Device getDeviceFromPtr(void* ptr) {
+  return Device(c10::DeviceType::MPS, 0); // ptr is not inspected: the result is always the MPS device with index 0
+}
+
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSEvent.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSEvent.h
new file mode 100644
index 0000000000000000000000000000000000000000..aee5c72b4f2b6aaca9544e553e38b460144ea6ba
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSEvent.h
@@ -0,0 +1,110 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::mps {
+
+// NOTE: don't create instances of this class directly.
+// Use MPSEventPool to acquire instances of MPSEvent.
+class MPSEvent {
+ public:
+  explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
+  ~MPSEvent();
+
+  // records an event on the stream
+  void record(bool needsLock, bool syncEvent = false);
+  // makes all future work submitted to the stream wait for this event.
+  bool wait(bool needsLock, bool syncEvent = false);
+  // schedules a notifyListener callback for the event.
+  bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
+  // checks if the event has already been signaled.
+  bool query() const;
+  // blocks the CPU thread until all the GPU work that was scheduled
+  // prior to recording this event is completed.
+  bool synchronize();
+  // resets this event with new parameters in case it gets reused from the event
+  // pool
+  void reset(MPSStream* stream, bool enable_timing);
+  // returns the unique ID of the event instance
+  id_t getID() const {
+    return m_id;
+  }
+  // returns the completion timestamp of the event
+  uint64_t getCompletionTime() const {
+    return m_completion_time;
+  }
+  // if already recorded, waits for cpu_sync_cv to be signaled
+  void waitForCpuSync();
+
+ private:
+  id_t m_id;
+  // enables measuring the completion time of the notifyListener of this event
+  bool m_enable_timing;
+  uint64_t m_signalCounter = 0;
+  MPSStream* m_stream = nullptr;
+  MTLSharedEvent_t m_event = nullptr;
+  MTLSharedEventListener* m_listener = nullptr;
+  // used to sync the events created on this Stream with CPU
+  std::mutex m_cpu_sync_mutex{};
+  std::condition_variable m_cpu_sync_cv{};
+  // CondVar predicate to sync the events created on this Stream with CPU
+  bool m_cpu_sync_completed = false;
+  // used to compute elapsed time
+  uint64_t m_completion_time = 0;
+
+  void recordLocked(bool syncEvent); // presumably the bodies of record/wait/notify run with the lock held - TODO confirm
+  bool waitLocked(bool syncEvent);
+  bool notifyLocked(MTLSharedEventNotificationBlock block);
+  void notifyCpuSync();
+  static uint64_t getTime() { // current CLOCK_MONOTONIC_RAW reading in nanoseconds
+    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
+  }
+};
+
+typedef std::unique_ptr> MPSEventPtr;
+
+class MPSEventPool { // hands out MPSEvent instances (see the NOTE on MPSEvent about not constructing them directly)
+ public:
+  explicit MPSEventPool(MPSStream* default_stream);
+  ~MPSEventPool();
+
+  MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
+  void emptyCache();
+
+  // these ID-based methods are mainly used for MPSHooks and torch.mps.Event() bindings
+  id_t acquireEvent(bool enable_timing);
+  void releaseEvent(id_t event_id);
+  void recordEvent(id_t event_id, bool syncEvent);
+  void waitForEvent(id_t event_id, bool syncEvent);
+  void synchronizeEvent(id_t event_id);
+  bool queryEvent(id_t event_id);
+  // returns elapsed time between two recorded events in milliseconds
+  double elapsedTime(id_t start_event_id, id_t end_event_id);
+
+ private:
+  MPSStream* m_default_stream = nullptr;
+  std::recursive_mutex m_mutex;
+  std::stack> m_pool{};
+  // dictionary to associate event IDs with event objects
+  // used to retain in-use events out of the pool
+  // for torch.mps.Event() bindings.
+  std::unordered_map m_in_use_events{};
+  uint64_t m_event_counter = 0;
+  std::function m_default_deleter;
+
+  MPSEvent* getInUseEvent(id_t event_id, bool locked = true); // presumably looks the ID up in m_in_use_events - TODO confirm
+};
+
+// shared_ptr is used to get MPSEventPool destroyed after dependent instances
+std::shared_ptr getMPSEventPool();
+
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..01a6d8bc876ab27357ea9c7007d4add08b8d87be
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h
@@ -0,0 +1,66 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace mps::detail {
+
+constexpr uint32_t PHILOX_STATE_N = 7;
+struct rng_data_pod { // plain-data holder for the generator state and seed
+  std::array state{1}; // first element is initialized to 1; the remaining elements are zero-initialized
+  uint64_t seed = default_rng_seed_val;
+};
+
+TORCH_API const Generator& getDefaultMPSGenerator();
+TORCH_API Generator
+createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
+
+} // namespace mps::detail
+
+struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
+  // Constructors
+  MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
+  ~MPSGeneratorImpl() override = default;
+
+  // MPSGeneratorImpl methods
+  std::shared_ptr clone() const;
+  void set_current_seed(uint64_t seed) override;
+  void set_offset(uint64_t offset) override;
+  uint64_t get_offset() const override;
+  uint64_t current_seed() const override;
+  uint64_t seed() override;
+  void set_state(const c10::TensorImpl& new_state) override;
+  c10::intrusive_ptr get_state() const override;
+  void update_philox_counters();
+
+  void set_engine(at::Philox4_32 engine) { // replaces the stored Philox engine with a copy of `engine`
+    engine_ = engine;
+  }
+  at::Philox4_32 engine() { // returns a copy of the stored engine (returned by value)
+    return engine_;
+  }
+  uint32_t* state_data() { // raw pointer to the first element of data_.state
+    return data_.state.data();
+  }
+  static DeviceType device_type() { // this generator always reports the MPS device type
+    return DeviceType::MPS;
+  }
+
+ private:
+  mps::detail::rng_data_pod data_; // state array and seed
+  at::Philox4_32 engine_;
+
+  MPSGeneratorImpl* clone_impl() const override;
+};
+
+} // namespace at
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSGuardImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSGuardImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..86b357aeac416c8d001142ed8bd8afb7c7aff8d9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSGuardImpl.h
@@ -0,0 +1,187 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __OBJC__
+#include 
+#include 
+#include 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::mps {
+
+typedef MPSEvent* mpsEvent_t;
+
+// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
+// https://github.com/pytorch/pytorch/issues/77170
+struct TORCH_API MPSGuardImpl final // device-guard implementation whose inline getters always report MPS device 0
+    : public c10::impl::DeviceGuardImplInterface {
+  static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
+
+  // constructors
+  MPSGuardImpl() {}
+  explicit MPSGuardImpl(c10::DeviceType t) { // t is only validated, never stored
+    TORCH_CHECK(
+        t == DeviceType::MPS,
+        "MPSGuardImpl initialized with non-MPS DeviceType: ",
+        t);
+  }
+
+  // returns the device type handled by this guard (always MPS)
+  c10::DeviceType type() const override {
+    return c10::DeviceType::MPS;
+  }
+
+  Device exchangeDevice(Device d) const override { // d is ignored; nothing is exchanged
+    return Device(c10::DeviceType::MPS, 0);
+  }
+
+  Device getDevice() const override { // always MPS device 0
+    return Device(c10::DeviceType::MPS, 0);
+  }
+
+  std::optional uncheckedGetDevice() const noexcept { // note: not marked override, unlike its neighbours
+    return Device(c10::DeviceType::MPS, 0);
+  }
+
+  void setDevice(Device d) const override { // only checks that d is an MPS device; no state changes here
+    TORCH_CHECK(d.is_mps(), "Expected a MPS device, but got ", d);
+  }
+
+  void uncheckedSetDevice(Device d) const noexcept override {
+    // TODO: Currently setting only device 0 (the body is intentionally empty)
+  }
+
+  Stream getStream(Device d) const override { // d is ignored; default stream on MPS device 0
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+
+  Stream getNewStream(Device, int priority = 0) const override { // no new stream is created; returns the default one
+    (void)priority;
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+
+  Stream getDefaultStream(Device d) const override { // d is ignored
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+
+  // NB: These do NOT set the current device
+  Stream exchangeStream(Stream s) const override { // s is ignored; returns the default stream
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+  DeviceIndex deviceCount() const noexcept override { // 1 when at::hasMPS() is true, otherwise 0
+    if (at::hasMPS()) {
+      // TODO: extend it for multi-device case
+      return 1;
+    } else {
+      return 0;
+    }
+  }
+
+  // Event-related functions (declared here; definitions are not in this header)
+  void createEvent(mpsEvent_t* event, const EventFlag flag) const;
+
+  void destroyEvent(void* event, const DeviceIndex device_index)
+      const noexcept override;
+
+  void record(
+      void** event,
+      const Stream& stream,
+      const DeviceIndex device_index,
+      const EventFlag flag) const override;
+
+  void block(void* event, const Stream& stream) const override;
+
+  bool queryEvent(void* event) const override;
+
+  void synchronizeEvent(void* event) const override;
+
+  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
+      const override;
+
+  void synchronizeDevice(const DeviceIndex device_index) const override;
+};
+
+/// A variant of OptionalDeviceGuard that is specialized for MPS.
+struct OptionalMPSGuard { // thin wrapper: every member forwards to guard_
+  explicit OptionalMPSGuard() : guard_() {} // starts with a default-constructed guard_
+
+  explicit OptionalMPSGuard(std::optional device_opt) // forwards the optional device to guard_
+      : guard_(device_opt) {}
+
+  /// Set the current MPS device to the passed device index, if it is not
+  /// nullopt
+  explicit OptionalMPSGuard(std::optional device_index_opt)
+      : guard_(device_index_opt) {}
+
+  // Neither copy nor move is allowed
+  OptionalMPSGuard(const OptionalMPSGuard&) = delete;
+  OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
+  OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
+  OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
+
+  /// Sets the MPS device to the given device, initializing the guard if it
+  /// is not already initialized.  Errors if the given device is not a MPS
+  /// device.
+  void set_device(Device device) {
+    guard_.set_device(device);
+  }
+
+  /// Sets the MPS device to the given device, initializing the guard if it is
+  /// not already initialized.  Errors if the given device is not a MPS device.
+  void reset_device(Device device) { // forwards to guard_.reset_device (not set_device)
+    guard_.reset_device(device);
+  }
+
+  /// Sets the MPS device to the given device index, initializing the guard if
+  /// it is not already initialized.
+  void set_index(DeviceIndex device_index) {
+    guard_.set_index(device_index);
+  }
+
+  /// Returns the device that was set immediately prior to initialization of the
+  /// guard, or nullopt if the guard is uninitialized.
+  std::optional original_device() const {
+    return guard_.original_device();
+  }
+
+  /// Returns the most recent device that was set using this device guard,
+  /// either from construction, or via set_device, if the guard is initialized,
+  /// or nullopt if the guard is uninitialized.
+  std::optional current_device() const {
+    return guard_.current_device();
+  }
+
+  /// Restore the original MPS device, resetting this guard to uninitialized
+  /// state.
+  void reset() {
+    guard_.reset();
+  }
+
+ private:
+  c10::impl::InlineOptionalDeviceGuard guard_; // the wrapped guard that does the actual work
+};
+
+C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl)
+
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSHooks.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSHooks.h
new file mode 100644
index 0000000000000000000000000000000000000000..5f743b3e1dff0644d09405dfde7bae7d9fd82f7d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSHooks.h
@@ -0,0 +1,76 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::mps {
+
+// The real implementation of MPSHooksInterface
+struct MPSHooks : public at::MPSHooksInterface {
+  MPSHooks(at::MPSHooksArgs) {} // the argument is accepted but unused
+  void init() const override;
+
+  // MPSDevice interface
+  bool hasMPS() const override;
+  bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
+
+  Device getDeviceFromPtr(void* data) const override;
+
+  // MPSGeneratorImpl interface
+  const Generator& getDefaultGenerator(
+      DeviceIndex device_index = -1) const override;
+  Generator getNewGenerator(DeviceIndex device_index = -1) const override;
+
+  // MPSStream interface
+  void deviceSynchronize() const override;
+  void commitStream() const override;
+  void* getCommandBuffer() const override;
+  void* getDispatchQueue() const override;
+
+  // MPSAllocator interface
+  Allocator* getMPSDeviceAllocator() const override;
+  void emptyCache() const override;
+  size_t getCurrentAllocatedMemory() const override;
+  size_t getDriverAllocatedMemory() const override;
+  size_t getRecommendedMaxMemory() const override;
+  void setMemoryFraction(double ratio) const override;
+  bool isPinnedPtr(const void* data) const override;
+  Allocator* getPinnedMemoryAllocator() const override;
+
+  // MPSProfiler interface
+  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
+      const override;
+  void profilerStopTrace() const override;
+
+  // MPSEvent interface (events are referred to by uint32_t ids)
+  uint32_t acquireEvent(bool enable_timing) const override;
+  void releaseEvent(uint32_t event_id) const override;
+  void recordEvent(uint32_t event_id) const override;
+  void waitForEvent(uint32_t event_id) const override;
+  void synchronizeEvent(uint32_t event_id) const override;
+  bool queryEvent(uint32_t event_id) const override;
+  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
+      const override;
+
+  bool isBuilt() const override { // unconditionally true for this implementation
+    return true;
+  }
+  bool isAvailable() const override { // availability is whatever hasMPS() reports
+    return hasMPS();
+  }
+  bool hasPrimaryContext(DeviceIndex device_index) const override { // device_index is ignored
+    // When MPS is available, it is always in use for the one device.
+    return true;
+  }
+};
+
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSProfiler.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSProfiler.h
new file mode 100644
index 0000000000000000000000000000000000000000..0c97168d8ce36a2062f40c228694843604ba35a4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSProfiler.h
@@ -0,0 +1,472 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef __OBJC__
+typedef void* MTLCaptureManager;
+#endif
+
+namespace at::mps {
+
+namespace Profiler {
+
+struct BaseInfo { // common base of the profiling records (OperationInfo, CpuFbInfo, CopyInfo)
+  // profiling info types
+  enum class Type {
+    GRAPH,
+    KERNEL,
+    COPY,
+    CPU_FALLBACK,
+  };
+
+  BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) // stores the three arguments; other members keep their defaults
+      : type(infoType), profileId(Id), handle(Handle) {}
+  virtual ~BaseInfo() = default;
+
+  // type of profiling info
+  Type type;
+  // unique profile ID for execution instances of operations or copies
+  uint64_t profileId;
+  // ID generated by os_signpost
+  // since it's possible to use event and interval-based signposts at the
+  // same time, we need separate IDs for each.
+  os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
+  // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
+  // GPUStartTime")
+  std::atomic totalGpuTime{0.0};
+  // accumulated Scheduling time in ms (obtained from CompletionHandler's
+  // "KernelEndTime - KernelStartTime")
+  std::atomic totalSchedulingTime{0.0};
+  // indicates if the operation or copy execution has completed
+  std::atomic_bool completed{false};
+  // handle used to identify the profile info's instance (usually the pointer)
+  const uintptr_t handle;
+
+  virtual const std::string toString( // overridden by each derived info struct in this header
+      double gpuTime = 0,
+      double schedulingTime = 0) const;
+  // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
+  static std::string buildTensorString(
+      const Tensor& tensor,
+      bool includeBufferId = false);
+  static uint64_t getTime() { // current CLOCK_MONOTONIC_RAW reading from clock_gettime_nsec_np (nanoseconds)
+    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
+  }
+};
+
+struct OperationInfo : BaseInfo { // profiling record for a graph or kernel execution
+  OperationInfo(
+      const void* Handle,
+      bool IsGraph,
+      uint64_t Id,
+      const std::string& StrKey)
+      : BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), // IsGraph selects the record type; the pointer is stored as an integer handle
+        strKey(StrKey) {}
+
+  uint64_t runCount = 0;
+  std::string strKey;
+
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0)
+      const override;
+
+  // builds a string for a kernel: the kernel name followed by ':' and the
+  // tensor string of each tensor, in order
+  static std::string buildKernelString(
+      const std::string& kernelName,
+      const TensorList& tensors,
+      bool includeBufferId = false) {
+    std::stringstream kernelStr;
+    kernelStr << kernelName;
+    for (const Tensor& tensor : tensors) {
+      kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
+    }
+    return kernelStr.str();
+  }
+};
+
+struct CpuFbInfo : BaseInfo { // profiling record for an op that falls back to CPU
+  CpuFbInfo(uint64_t Id, const std::string& OpName)
+      : BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {} // the handle is always 0 for this record type
+
+  uint64_t runCount = 0;
+  // the current and total overhead of copies in bytes required to convert the
+  // Op's input tensors from MPS to CPU and then output from CPU back to MPS
+  size_t currentCopyOverhead = 0;
+  size_t totalCopyOverhead = 0;
+  std::string opName;
+  std::string strKey;
+  uint64_t startTime = 0;
+
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0)
+      const override;
+
+  void updateCopyOverhead(const TensorList& tensors) { // recomputes the current overhead and adds it to the running total
+    currentCopyOverhead = 0;
+    for (const Tensor& tensor : tensors) {
+      if (tensor.defined()) { // undefined tensors contribute nothing
+        currentCopyOverhead += tensor.nbytes();
+      }
+    }
+    totalCopyOverhead += currentCopyOverhead;
+  }
+};
+
+struct CopyInfo : BaseInfo { // profiling record for a single copy
+  enum class Kind {
+    MPS_TO_MPS,
+    MPS_TO_CPU,
+    CPU_TO_MPS,
+  };
+
+  CopyInfo(
+      const void* Handle,
+      size_t Length,
+      uint64_t Id,
+      bool IsNonBlocking,
+      bool UsesBlitter)
+      : BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
+        kind(Kind::MPS_TO_MPS), // fixed initial value; not derived from the constructor arguments
+        length(Length),
+        isNonBlocking(IsNonBlocking),
+        usesBlitter(UsesBlitter) {}
+
+  Kind kind;
+  size_t length;
+  bool isNonBlocking;
+  bool usesBlitter;
+  std::string srcStrKey;
+  std::string dstStrKey;
+  // for copies that don't use blitters, we measure CPU time
+  uint64_t startTime = 0;
+
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0)
+      const override;
+
+  static std::string buildTensorString(
+      const void* buffer,
+      const OptionalTensorRef tensor,
+      bool includeBufferId = false);
+
+  static bool isStorageOnMPS( // the tensor's device decides when a tensor is given; otherwise the allocator is asked about buffer
+      const void* buffer,
+      const OptionalTensorRef tensor) {
+    if (tensor.has_value()) {
+      return tensor->device().type() == at::kMPS;
+    }
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
+    // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
+    return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
+  }
+
+  static Kind getCopyKind( // classifies a copy by where its source and destination live
+      const void* srcBuffer,
+      const void* dstBuffer,
+      const OptionalTensorRef srcTensor,
+      const OptionalTensorRef dstTensor) {
+    const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
+    const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS); // at least one side is expected to be on MPS
+    if (isSrcOnMPS && !isDstOnMPS) {
+      return Kind::MPS_TO_CPU;
+    } else if (!isSrcOnMPS && isDstOnMPS) {
+      return Kind::CPU_TO_MPS;
+    }
+    return Kind::MPS_TO_MPS; // reached when both sides are on MPS (or, with the debug assert compiled out, neither)
+  }
+};
+
+struct CopyStat : CopyInfo { // aggregate counters for copies, labelled by kindStr
+  explicit CopyStat(std::string CopyKindStr)
+      : CopyInfo(nullptr, 0, 0, false, false), // base record gets a null handle, zero length/id and both flags false
+        kindStr(std::move(CopyKindStr)) {}
+  // total number of copies
+  size_t totalCount = 0;
+  // number of Scalar copies (i.e., less than sizeof(int64))
+  size_t scalarsCount = 0;
+  // number of blocking copies (i.e., require syncing to GPU)
+  size_t blockingCount = 0;
+  // number of copies that used memcpy(), instead of Metal Blit Encoder
+  size_t memcpyCount = 0;
+  // accumulated GPU time in ms for the scalar copies
+  std::atomic scalarsGpuTime{0.0};
+  // copy kind in string type
+  std::string kindStr;
+};
+
+class MPSProfiler { // profiling state and API for MPS operations, copies and CPU fallbacks
+ public:
+  // lower 16 bits used for profiler options
+  enum ProfileOptions : uint32_t {
+    OPTIONS_NONE = 0,
+    // ALL_* means all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
+    // etc.); used for convenience to avoid OR-ing the bit flags manually.
+    // trace all signpost types using events
+    ALL_SIGNPOST_EVENTS = (1 << 0),
+    // trace all signpost types using intervals
+    ALL_SIGNPOST_INTERVALS = (1 << 1),
+    // always wait for command buffer to finish executing after each commit
+    WAIT_UNTIL_COMPLETED = (1 << 2),
+    // for interval-based signposts, include the scheduling portion of
+    // Graph/Kernel/Copy executions as well.
+    // if flag is disabled, only "GPU run time" is included in interval,
+    // and not schedule time.
+    INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
+
+    // use these if you need to trace signpost types individually (rarely
+    // required). trace signposts using intervals
+    USE_INTERVALS = (1 << 4),
+    // trace signpost by emitting events
+    USE_EVENTS = (1 << 5),
+    // used for sanity check (Change this when new option added)
+    OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
+  };
+
+  // when adding new types, #define the type string in MPSProfiler.mm as well.
+  // upper 16 bits used for event types
+  enum SignpostTypes : uint32_t {
+    SIGNPOST_NONE = 0,
+    // trace signposts for PyTorch operation executions
+    RUN_OPERATION = (1 << 16),
+    // trace signposts for blitter copies
+    BLIT_COPY = (1 << 17),
+    // trace signposts for ops that fall back on CPU
+    CPU_FALLBACK = (1 << 18),
+    // used for sanity check (Change this when new type added)
+    SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
+  };
+
+  enum LogOptions : uint32_t {
+    LOG_NONE = 0,
+
+    // Info logging options during execution
+    // -------------------------------------
+    // prints operation info (id/key/run_count) during execution
+    OPERATION_INFO = (1 << 0),
+    // prints copy info (src/dst tensors/buffers, size, etc.) during execution
+    COPY_INFO = (1 << 1),
+    // prints CPU Fallback info (id/runCount/opName/copyOverhead) during
+    // execution
+    CPU_FALLBACK_INFO = (1 << 2),
+
+    // Profiling Statistics logging options when process terminates
+    // ------------------------------------------------------------
+    // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
+    // process terminates; this is a convenience so the following stats bit
+    // flags need not be combined manually
+    ALL_STATS = (1 << 3),
+    // prints operation stats (GPU times, run count, etc.) before process
+    // terminates
+    OPERATION_STATS = (1 << 4),
+    // prints copies stats (GPU times, copy kinds, sizes, etc.) before process
+    // terminates
+    COPY_STATS = (1 << 5),
+    // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
+    // for tensors, etc.) before process terminates
+    CPU_FALLBACK_STATS = (1 << 6),
+
+    // Metadata format options when logging the info
+    // ---------------------------------------------
+    // if enabled, includes GPU run time in metadata (i.e.,
+    // GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
+    // ms])
+    INCLUDE_GPU_TIME = (1 << 7),
+    // if enabled, includes GPU scheduling time in metadata separately
+    // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
+    // e.g., [GPU=0.324 ms, KRNL=0.036 ms]
+    INCLUDE_KERNEL_TIME = (1 << 8),
+    // if enabled, includes the unique buffer ID in metadata for the storage
+    // of a tensor that was allocated on MPSAllocator. This is useful (along
+    // with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
+    // involved with various operations.
+    INCLUDE_BUFFER_ID = (1 << 9),
+
+    // used for sanity check (Change this when new option added)
+    LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
+  };
+
+  explicit MPSProfiler();
+  ~MPSProfiler();
+
+  // the handle is either "MPSGraph*" or "id" for Metal
+  // Kernels. The beginProfile*() functions return a profileId which is unique
+  // per graph/kernel/copy
+  uint64_t beginProfileKernel(
+      const void* handle,
+      const std::string& strKey,
+      bool isGraph);
+  uint64_t beginProfileKernel(
+      const void* handle,
+      const std::string& kernelName,
+      const TensorList& tensors);
+  uint64_t beginProfileCopy(
+      const void* srcBuffer,
+      const void* dstBuffer,
+      const OptionalTensorRef srcTensor,
+      const OptionalTensorRef dstTensor,
+      size_t length,
+      bool isNonBlocking,
+      bool usesBlitter = true);
+  uint64_t beginProfileCPUFallback(
+      const std::string& opName,
+      const TensorList& tensors);
+  void beginProfileGPUInterval(const void* handle);
+
+  void endProfileCopy(uint64_t profileId, SyncType syncType);
+  void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
+  void endProfileCPUFallback(const std::string& opName);
+
+  // these are used to hook into Python bindings for torch.mps.profiler module.
+  // this enables generating OS Signpost traces from MPSProfiler on-demand
+  // during runtime (instead of environment variables).
+  // The "mode" could be either "interval", "event", or both "interval,event"
+  // for interval-based and/or event-based signpost tracing.
+  void StartTrace(const std::string& mode, bool waitUntilCompleted);
+  void StopTrace();
+
+  // Abstractions for GPU trace capturing
+  bool isCaptureEnabled() const;
+  bool isCapturing() const;
+  void startCapture(const std::string& name, MPSStream* stream = nullptr);
+  void stopCapture(MPSStream* stream = nullptr);
+
+  // convenience functions to indicate whether signpost tracing or
+  // logging are enabled for the SignpostTypes
+  bool isOperationProfilingEnabled() const { // true if RUN_OPERATION signposts or either operation log option is set
+    return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
+        (m_log_options &
+         (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
+  }
+  bool isCopyProfilingEnabled() const { // true if BLIT_COPY signposts or either copy log option is set
+    return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
+        (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
+  }
+  bool isCPUFallbackProfilingEnabled() const { // true if CPU_FALLBACK signposts or either CPU-fallback log option is set
+    return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
+        (m_log_options &
+         (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
+  }
+  bool isSignpostTracingEnabled() const { // true if any signpost type bit is set
+    return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
+  }
+
+ private:
+  // indicates what type of signpost types are enabled and traced by MPS
+  // profiler.
+  uint32_t m_signpost_types = 0;
+  uint32_t m_profile_options = 0;
+  uint32_t m_log_options = 0;
+  uint64_t m_kernel_counter = 0;
+  uint64_t m_graph_counter = 0;
+  uint64_t m_cpu_fb_counter = 0;
+  uint64_t m_copy_counter = 0;
+  // technically, it's possible to trace both events and intervals at the same
+  // time so we use separate os_log categories for them
+  os_log_t m_os_log_events;
+  os_log_t m_os_log_intervals;
+  // stats logging could run either from destructor or signal handler
+  // so this is used to check if logging has already started.
+  std::atomic_bool hasLoggedStats{false};
+  // indicates there are pending completionHandler callbacks that haven't been
+  // called yet.
+  std::atomic_bool hasPendingCompletionHandlers{false};
+  // used to capture sigint signal to log profiling stats
+  static struct sigaction currentSigint, previousSigint;
+
+  // We use the following lists for two reasons:
+  // 1- for interval-based signposts the "begin" point won't be in same function
+  // as the "end" point where we need to be able to retrieve signpost's info
+  // 2- if Operations info need to be logged when process ends using
+  // LogOptions::OPERATION_INFO.
+
+  // the pointer key for this map is either "MPSGraph*" or
+  // "id" for Metal Kernels. This list is retained and
+  // could be logged along with aggregate profiling numbers when the process
+  // ends.
+  std::unordered_map>
+      m_op_info_list{};
+  // the string key for this map is the op name that we fall back to execute on
+  // CPU. This list is retained and could be logged along with aggregate
+  // profiling numbers when the process ends.
+  std::unordered_map>
+      m_cpu_fb_info_list{};
+  // this list contains the info for copies, and its key is the unique profileId
+  // which is generated from m_copy_counter
+  // The copyInfo list is not retained.
+  std::unordered_map> m_copy_info_list{};
+  // a short list that contains copy stats
+  std::unordered_map>
+      m_copy_stat_list{};
+
+  mutable MTLCaptureManager* captureManager = nil;
+  unsigned captureCount = 0;
+
+  void initialize();
+  void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
+  void endProfileExecution(
+      BaseInfo& info,
+      os_signpost_id_t event_signpost_id,
+      os_signpost_id_t interval_signpost_id,
+      double gpuTime,
+      double schedulingTime);
+  void addProfilerScheduledHandler(BaseInfo& info);
+  void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
+  void emitSignpostEvent(
+      SignpostTypes signpost_type,
+      os_signpost_id_t signpost_id,
+      const std::string& msg) const;
+  void beginSignpostInterval(
+      SignpostTypes signpost_type,
+      os_signpost_id_t signpost_id,
+      const std::string& msg) const;
+  void endSignpostInterval(
+      SignpostTypes signpost_type,
+      os_signpost_id_t signpost_id) const;
+
+  void updateCopyStats(
+      const CopyInfo& copyInfo,
+      double gpuTime,
+      double schedulingTime);
+  // returns true if logging the profiling info "during the execution" is
+  // enabled
+  bool isProfileInfoLoggingEnabled(
+      BaseInfo::Type infoType,
+      bool isExecutionEnded);
+  // logs all the profiling stats that are enabled
+  void logProfilingStats();
+  // logs kernel profiling stats when the process ends.
+  void logOperationsProfilingStats(std::FILE* f) const;
+  // logs CPU Fallback profiling stats when the process ends.
+  void logCPUFallbackProfilingStats(std::FILE* f) const;
+  // logs copy profiling stats when the process ends.
+  void logCopyProfilingStats(std::FILE* f) const;
+
+  os_signpost_id_t generateSignpostId(
+      os_signpost_type_t signpostType,
+      const void* ptr = nullptr);
+  static SignpostTypes getSignpostType(BaseInfo::Type infoType);
+  static void handleIntSignal(int signal);
+};
+
+} // namespace Profiler
+
+Profiler::MPSProfiler& getMPSProfiler();
+
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSStream.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSStream.h
new file mode 100644
index 0000000000000000000000000000000000000000..e721649f42a607942a06630b7b9b9b0970049174
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/mps/MPSStream.h
@@ -0,0 +1,171 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __OBJC__
+#include 
+#include 
+#include 
+#include 
+typedef MPSCommandBuffer* MPSCommandBuffer_t;
+typedef id MTLCommandQueue_t;
+typedef id MTLComputeCommandEncoder_t;
+typedef id MTLSharedEvent_t;
+typedef id MTLDevice_t;
+typedef id MTLBuffer_t;
+#else
+#include 
+typedef void* MPSCommandBuffer_t;
+typedef void* MPSGraph;
+typedef void* MPSGraphExecutionDescriptor;
+typedef void* MPSGraphCompilationDescriptor;
+typedef void* MTLCommandQueue_t;
+typedef void* MTLComputeCommandEncoder_t;
+typedef void* MTLSharedEvent_t;
+typedef void* MTLDevice_t;
+typedef void* MTLBuffer_t;
+typedef void* MTLCommandBufferHandler;
+typedef void* NSDictionary;
+#define nil NULL
+#endif
+
+namespace at::mps {
+
+//-----------------------------------------------------------------
+//  MPSStream
+//-----------------------------------------------------------------
+
+enum class SyncType { // selects what MPSStream::synchronize() does with the command buffer
+  NONE, // no commit to command buffer
+  COMMIT, // commit and flush the command buffer
+  COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
+  COMMIT_AND_CONTINUE, // commit and continue with a new underlying command buffer
+  COMMIT_ADAPTIVE, // commit adaptively based on available memory
+};
+
+class TORCH_API MPSStream { // wraps a Stream together with its Metal command queue, buffers and encoder
+ public:
+  enum Unchecked { UNCHECKED };
+
+  /// Construct a MPSStream from a Stream.  This construction is checked,
+  /// and will raise an error if the Stream is not, in fact, a MPS stream.
+  explicit MPSStream(Stream stream);
+
+  ~MPSStream();
+
+  MTLCommandQueue_t commandQueue() const { // returns the stored command queue
+    return _commandQueue;
+  }
+
+  dispatch_queue_t queue() const { // returns the stored serial dispatch queue
+    return _serialQueue;
+  }
+
+  MPSCommandBuffer_t commandBuffer();
+  MTLComputeCommandEncoder_t commandEncoder();
+  void endKernelCoalescing();
+  void synchronize(SyncType syncType);
+  void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
+  void copy(MTLBuffer_t srcBuffer,
+            MTLBuffer_t dstBuffer,
+            size_t length,
+            size_t srcOffset,
+            size_t dstOffset,
+            uint64_t profileId,
+            SyncType syncType = SyncType::NONE);
+  void copy_and_sync(MTLBuffer_t srcBuffer,
+                     MTLBuffer_t dstBuffer,
+                     size_t length,
+                     size_t srcOffset,
+                     size_t dstOffset,
+                     bool non_blocking,
+                     uint64_t profileId);
+  void executeMPSGraph(MPSGraph* mpsGraph,
+                       NSDictionary* feeds,
+                       NSDictionary* results,
+                       SyncType syncType = SyncType::NONE);
+  void addCompletedHandler(MTLCommandBufferHandler block);
+
+  /// Get the MPS device index that this stream is associated with.
+  c10::DeviceIndex device_index() const {
+    return _stream.device_index();
+  }
+
+  MTLCommandQueue_t stream() const { // returns the same value as commandQueue()
+    return _commandQueue;
+  }
+
+  MTLDevice_t device() const;
+
+  /// Explicit conversion to Stream.
+  Stream unwrap() const {
+    return _stream;
+  }
+
+  MTLBuffer_t getErrorBuffer();
+  void checkLastError();
+
+ private:
+  Stream _stream;
+  MTLCommandQueue_t _commandQueue = nil;
+  MPSCommandBuffer_t _commandBuffer = nil;
+  MPSCommandBuffer_t _prevCommandBuffer = nil;
+  MTLComputeCommandEncoder_t _commandEncoder = nil;
+  MPSGraphExecutionDescriptor* _executionDescriptor = nil;
+  MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
+  dispatch_queue_t _serialQueue = nullptr;
+  // CommitAndContinue is enabled by default
+  bool _enableCommitAndContinue = true;
+  // Buffer that contains last raised error
+  MTLBuffer_t _errorBuffer = nil;
+
+  // these commit functions are private; code outside MPSStream should call
+  // synchronize() instead
+  void commit();
+  void commitAndWait();
+  void commitAndContinue();
+  void flush();
+};
+
+/**
+ * Get the current MPS stream
+ */
+TORCH_API MPSStream* getCurrentMPSStream();
+
+/**
+ * Get the default MPS stream
+ */
+TORCH_API MPSStream* getDefaultMPSStream();
+
+//-----------------------------------------------------------------
+//  MPSStreamImpl
+//-----------------------------------------------------------------
+
+class TORCH_API MPSStreamImpl { // static-only holder for the MPSStream instance
+ public:
+  /**
+   * Gets the single instance of the MPSStream.
+   */
+  static MPSStream* getInstance();
+
+ private:
+  static MPSStream* _stream; // NOTE(review): presumably the instance returned by getInstance() - confirm in the .mm file
+  MPSStreamImpl(); // private constructor: cannot be instantiated from outside the class
+};
+
+#ifdef __OBJC__
+void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
+#endif
+} // namespace at::mps
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Activation.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Activation.h
new file mode 100644
index 0000000000000000000000000000000000000000..3d91338d6cdd4c3387e4e97eb8a724d60cca3834
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Activation.h
@@ -0,0 +1,78 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+struct TensorIterator;
+struct TensorIteratorBase;
+class TensorBase;
+}
+
+namespace at::native {
+
+using structured_activation_fn = void (*)(TensorIteratorBase&);  // function-pointer type: kernel taking only a TensorIteratorBase&
+using structured_activation_backward_fn = void (*)(TensorIteratorBase&);  // same signature, separate alias for backward kernels
+
+using activation_fn = void (*)(TensorIterator&);  // variants taking TensorIterator& instead of TensorIteratorBase&
+using activation_backward_fn = void (*)(TensorIterator&);
+using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
+using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
+using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
+using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
+using hardsigmoid_fn = void(*)(TensorIteratorBase&);
+using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
+using hardswish_fn = void(*)(TensorIterator&);
+using hardswish_backward_fn = void(*)(TensorIterator&);
+using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
+using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);  // three Scalars plus a bool flag (meaning not visible here)
+using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);  // the only alias here that takes tensors rather than an iterator
+using gelu_fn = void (*)(TensorIteratorBase&, GeluType);  // GeluType is declared outside this chunk
+using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
+using glu_jvp_fn = void (*)(TensorIteratorBase&);
+
+DECLARE_DISPATCH(elu_fn, elu_stub)  // DECLARE_DISPATCH is defined outside this chunk; presumably declares the stub (2nd arg) with the signature alias (1st arg) -- confirm
+DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub)
+DECLARE_DISPATCH(softplus_fn, softplus_stub)
+DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub)
+DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub)
+DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub)
+DECLARE_DISPATCH(threshold_fn, threshold_stub)
+DECLARE_DISPATCH(gelu_fn, GeluKernel)  // NOTE(review): named *Kernel rather than *_stub, unlike the other entries
+DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel)
+DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub)
+DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub)
+DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub)
+DECLARE_DISPATCH(hardswish_fn, hardswish_stub)
+DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub)
+DECLARE_DISPATCH(shrink_fn, hardshrink_stub)  // hardshrink uses shrink_fn; softshrink has its own identically-shaped alias
+DECLARE_DISPATCH(softshrink_fn, softshrink_stub)
+DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub)
+DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub)
+DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub)
+DECLARE_DISPATCH(structured_activation_fn, glu_stub)
+DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub)
+DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub)
+DECLARE_DISPATCH(structured_activation_fn, silu_stub)
+DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub)
+DECLARE_DISPATCH(structured_activation_fn, mish_stub)
+DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub)
+DECLARE_DISPATCH(activation_fn, prelu_stub)
+DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BinaryOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BinaryOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..d1632f1d978b92db9e0d73e9ebcc460c0afa7c15
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BinaryOps.h
@@ -0,0 +1,124 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+
+namespace at {
+struct TensorIterator;
+struct TensorIteratorBase;
+}
+
+namespace at::native {
+
+// Validates `alpha` against the result `dtype` with three TORCH_CHECKs, one per rule quoted in its message.
+inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
+  TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
+              "Boolean alpha only supported for Boolean results.");
+  TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) || alpha.isIntegral(true),
+              "For integral input tensors, argument alpha must not be a floating point number.");
+  TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
+              "For non-complex input tensors, argument alpha must not be a complex number.");
+}
+
+// Basic checking for all sub functions: both-bool operands fail the first check, any bool operand the second.
+inline void sub_check(const TensorBase& self, const TensorBase& other) {
+  TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
+              "Subtraction, the `-` operator, with two bool tensors is not supported. "
+              "Use the `^` or `logical_xor()` operator instead.");
+  TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
+              "Subtraction, the `-` operator, with a bool tensor is not supported. "
+              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
+}
+
+inline void sub_check(const TensorBase& self, const Scalar& scalar) {  // Scalar overload: same two bool checks, using scalar.isBoolean()
+  TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
+              "Subtraction, the `-` operator, with two bool tensors is not supported. "
+              "Use the `^` or `logical_xor()` operator instead.");
+  TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
+              "Subtraction, the `-` operator, with a bool tensor is not supported. "
+              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
+}
+
+using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);  // "structured_*" aliases take TensorIteratorBase&
+using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
+using structured_binary_fn = void(*)(TensorIteratorBase&);
+
+using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);  // NOTE(review): takes TensorIteratorBase&, unlike the other binary_fn* aliases below
+using binary_fn_double = void(*)(TensorIterator&, double);
+using binary_fn = void(*)(TensorIterator&);
+using binary_clamp_fn_alpha =
+    void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
+
+// NB: codegenned
+DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub)  // DECLARE_DISPATCH is defined outside this chunk; presumably declares the stub (2nd arg) with the signature alias (1st arg) -- confirm
+
+DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub)
+DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub)
+DECLARE_DISPATCH(structured_binary_fn, mul_stub)
+DECLARE_DISPATCH(structured_binary_fn, div_true_stub)
+DECLARE_DISPATCH(structured_binary_fn, div_floor_stub)
+DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub)
+DECLARE_DISPATCH(structured_binary_fn, atan2_stub)
+DECLARE_DISPATCH(structured_binary_fn, remainder_stub)
+DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub)
+DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub)
+DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub)
+DECLARE_DISPATCH(structured_binary_fn, lshift_stub)
+DECLARE_DISPATCH(structured_binary_fn, rshift_stub)
+DECLARE_DISPATCH(binary_fn, logical_xor_stub)
+DECLARE_DISPATCH(binary_fn, logical_and_stub)
+DECLARE_DISPATCH(binary_fn, logical_or_stub)
+DECLARE_DISPATCH(structured_binary_fn, lt_stub)
+DECLARE_DISPATCH(structured_binary_fn, le_stub)
+DECLARE_DISPATCH(structured_binary_fn, gt_stub)
+DECLARE_DISPATCH(structured_binary_fn, ge_stub)
+DECLARE_DISPATCH(structured_binary_fn, eq_stub)
+DECLARE_DISPATCH(structured_binary_fn, ne_stub)
+DECLARE_DISPATCH(binary_fn, max_elementwise_stub)
+DECLARE_DISPATCH(binary_fn, min_elementwise_stub)
+DECLARE_DISPATCH(structured_binary_fn, maximum_stub)
+DECLARE_DISPATCH(structured_binary_fn, minimum_stub)
+DECLARE_DISPATCH(structured_binary_fn, fmax_stub)
+DECLARE_DISPATCH(structured_binary_fn, fmin_stub)
+DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub)  // the two *_double stubs receive one extra double argument
+DECLARE_DISPATCH(binary_fn_double, huber_stub)
+DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub)
+DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub)
+DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub)
+DECLARE_DISPATCH(structured_binary_fn, mse_stub)
+DECLARE_DISPATCH(structured_binary_fn, fmod_stub)
+DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub)
+DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub)
+DECLARE_DISPATCH(structured_binary_fn, gcd_stub)
+DECLARE_DISPATCH(structured_binary_fn, lcm_stub)
+DECLARE_DISPATCH(structured_binary_fn, hypot_stub)
+DECLARE_DISPATCH(structured_binary_fn, igamma_stub)
+DECLARE_DISPATCH(structured_binary_fn, igammac_stub)
+DECLARE_DISPATCH(structured_binary_fn, nextafter_stub)
+DECLARE_DISPATCH(structured_binary_fn, heaviside_stub)
+DECLARE_DISPATCH(structured_binary_fn, copysign_stub)
+DECLARE_DISPATCH(structured_binary_fn, xlogy_stub)
+DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub)
+DECLARE_DISPATCH(structured_binary_fn, zeta_stub)
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub)
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub)
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub)
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub)
+DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub)
+DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub)
+DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub)
+DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub)
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub)
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub)
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub)
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BucketizationUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BucketizationUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..dcf766d6e007a78e4be6de84e4ac693467604b34
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BucketizationUtils.h
@@ -0,0 +1,178 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+
+// Prepares searchsorted operands. The raw_* tensors are the caller's originals; a trimmed_* output is
+// assigned only when a new tensor is needed: a contiguous copy of a non-contiguous raw tensor, and/or
+// a cast to the common result type when the input and boundary dtypes differ, so that comparisons are
+// done between the same types. A trimmed_* tensor that still holds its incoming value (typically null)
+// means the corresponding raw_* tensor was already contiguous and of the right type; use that instead.
+inline void searchsorted_maybe_trim_input_tensors(
+    Tensor& trimmed_input,
+    Tensor& trimmed_boundaries,
+    Tensor& trimmed_sorter,
+    const Tensor& raw_input,
+    const Tensor& raw_boundaries,
+    const Tensor& raw_sorter) {
+  const bool input_contig = raw_input.is_contiguous();
+  const bool boundaries_contig = raw_boundaries.is_contiguous();
+  const bool sorter_contig = raw_sorter.is_contiguous();
+
+  if (!input_contig) {
+    TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
+      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
+      "tensor if possible. This message will only appear once per program.");
+    trimmed_input = raw_input.contiguous();
+  }
+  if (!boundaries_contig) {
+    TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
+      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
+      "tensor if possible. This message will only appear once per program.");
+    trimmed_boundaries = raw_boundaries.contiguous();
+  }
+  if (!sorter_contig) {
+    TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
+      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
+      "tensor if possible. This message will only appear once per program.");
+    trimmed_sorter = raw_sorter.contiguous();
+  }
+  if (raw_input.dtype() == raw_boundaries.dtype()) {
+    return;  // dtypes already agree: no promotion step
+  }
+  at::native::ResultTypeState promo_state = {};
+  promo_state = at::native::update_result_type_state(raw_boundaries, promo_state);
+  promo_state = at::native::update_result_type_state(raw_input, promo_state);
+  const ScalarType promoted = at::native::result_type(promo_state);
+  TORCH_INTERNAL_ASSERT(promoted != ScalarType::Undefined);
+  if (promoted != raw_input.scalar_type()) {
+    trimmed_input = (input_contig ? raw_input : trimmed_input).to(promoted);
+  }
+  if (promoted != raw_boundaries.scalar_type()) {
+    trimmed_boundaries = (boundaries_contig ? raw_boundaries : trimmed_boundaries).to(promoted);
+  }
+}
+
+// Overload without a sorter: forwards to the six-argument version, supplying default-constructed
+// tensors for both sorter slots. Per the upstream note it is unused here but needed by an internal
+// jagged tensor class.
+inline void searchsorted_maybe_trim_input_tensors(
+    Tensor& trimmed_input,
+    Tensor& trimmed_boundaries,
+    const Tensor& raw_input,
+    const Tensor& raw_boundaries) {
+  // Local placeholders for the sorter arguments; anything the callee assigns to sorter_out is
+  // discarded when this function returns.
+  Tensor sorter_out;
+  Tensor sorter_in;
+  searchsorted_maybe_trim_input_tensors(
+      trimmed_input, trimmed_boundaries, sorter_out,
+      raw_input, raw_boundaries, sorter_in);
+}
+
+// Returns true when both tensors have the same number of dimensions and equal sizes in every
+// dimension except the last one.
+inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
+  const int64_t rank = boundaries.dim();
+  // A rank mismatch leaves `matched` false, so the loop below never reads the size arrays.
+  bool matched = (rank == input.dim());
+  // Walk dimensions 0 .. rank-2 and stop at the first size that differs.
+  for (int64_t d = 0; matched && d + 1 < rank; ++d) {
+    matched = (boundaries.sizes()[d] == input.sizes()[d]);
+  }
+  // Still true only if the ranks agreed and no leading dimension differed.
+  return matched;
+}
+
+inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
+  // Build a tensor from the scalar on `device` and flag it as a wrapped number, which adopts the scalar
+  // promotion rules defined in native/TypeProperties.h (same type promotion as binary operations).
+  Tensor wrapped = c10::scalar_to_tensor(scalar, device);
+  wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
+  return wrapped;
+}
+
+// Validates searchsorted arguments with a sequence of TORCH_CHECKs, in this order: side/right
+// agreement, boundaries/input device match, sorter checks (device, size, dtype, index range),
+// scalar-input rule, boundaries rank, leading-dimension match, output dtype, int32 size limit.
+inline void searchsorted_pre_check(
+    const Tensor& boundaries, const Tensor& input, const Tensor& output,
+    const bool out_int32, const bool right,
+    const std::optional<std::string_view> side_opt,
+    const Tensor& sorter) {
+  if (side_opt) {
+    const std::string_view side = *side_opt;
+    TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
+      "got ", side);
+
+    // assume the user has not explicitly set (right=False, side="right")
+    TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
+    "of ", side, " while right was True");
+  }
+
+  TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
+    "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
+    "tensor device type ", input.device());
+
+  if (sorter.defined()) {
+    // The second device printed below belongs to the boundaries tensor, so the message names it as such.
+    TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
+      "have same device type, but got sorter tensor device type ", sorter.device(), " and boundary tensor ",
+      "device type ", boundaries.device());
+    TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
+      "size, but got boundary tensor ", boundaries.sizes(), " and got sorter tensor ", sorter.sizes());
+
+    TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
+      "dtype but got dtype ", sorter.scalar_type());
+
+    if (sorter.numel() > 0) {
+      auto minmax = sorter.aminmax();
+      int64_t vmin = std::get<0>(minmax).item().toLong();
+      int64_t vmax = std::get<1>(minmax).item().toLong();
+      TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
+    }
+  }
+
+  TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
+    "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
+    "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
+    input.numel(), ")");
+
+  TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
+    "got 0 dimension");
+
+  TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
+    "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
+    "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
+    input.sizes());
+
+  ScalarType output_dtype = output.scalar_type();
+  TORCH_CHECK(
+      (output_dtype == ScalarType::Long && !out_int32) ||
+          (output_dtype == ScalarType::Int && out_int32),
+      "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
+      "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
+      " and out_int32 flag is ", (out_int32 ? "True" : "False"));
+
+  if (out_int32) {
+    TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
+      "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
+      boundaries.sizes().back());
+  }
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CPUFallback.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CPUFallback.h
new file mode 100644
index 0000000000000000000000000000000000000000..daad5e2fdad114d3a45076e682d5f9c0417db9f3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CPUFallback.h
@@ -0,0 +1,51 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+// This function implements a boxed fallback to CPU.
+// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
+TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
+                            c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
+
+// This is a helper that backends can use to directly call their boxed CPU fallback (a struct exposing a static call()).
+// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
+template  // NOTE(review): template parameter list appears to be missing here and below (angle-bracket text lost?) -- restore before compiling
+struct _call_fallback_fn final {};  // empty primary declaration
+
+template
+struct _call_fallback_fn final {  // presumably a partial specialization of the struct above; its argument list is not visible
+    static ReturnType call(typename c10::maybe_keep_symint::type... args) {
+        auto op = c10::Dispatcher::singleton()
+            // TODO: figure out how to make compiler happy without dynamic casts
+            .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)  // look the operator up by Op's name and overload name
+            //.findSchemaOrThrow("a", "b")
+            .typed::type...)>();
+        return c10::impl::BoxedKernelWrapper::type...)>::call(  // invoke through the boxed-kernel wrapper and return its result
+            c10::BoxedKernel::makeFromFunction(),
+            op,
+            c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
+            // TODO: get std::forward<> to work
+            args...
+            );
+    }
+};
+
+template
+using call_fallback_fn_symint = _call_fallback_fn;  // alias of _call_fallback_fn; template arguments not visible in this text
+
+template
+using call_fallback_fn = _call_fallback_fn;  // alias of _call_fallback_fn; template arguments not visible in this text
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h
new file mode 100644
index 0000000000000000000000000000000000000000..51401ef05a7ea39d3f7969424f50fe838652dd35
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h
@@ -0,0 +1,268 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#include 
+
+#pragma once
+
+namespace at::native {
+
+namespace {
+
+// operator_brackets_proxy is used in
+// CompositeRandomAccessor in place of operator[].
+// For some iterators, references returned by operator[]
+// could become invalid, operator_brackets_proxy tries to
+// resolve that by making accessor[n] to be equivalent to
+// *(accessor + n).
+template  // NOTE(review): template parameter list appears to be missing (the body uses `Accessor`) -- restore before compiling
+class operator_brackets_proxy {
+  using reference = typename std::iterator_traits::reference;
+  using value_type = typename std::iterator_traits::value_type;
+
+public:
+  C10_HOST_DEVICE
+  operator_brackets_proxy(Accessor const& accessor)
+    : accessor(accessor)  // keeps its own copy of the accessor
+  {}
+
+  C10_HOST_DEVICE
+  operator reference() {  // implicit conversion: dereferences the stored accessor
+    return *accessor;
+  }
+
+  C10_HOST_DEVICE
+  reference operator*() {  // explicit dereference, same result as the conversion above
+    return *accessor;
+  }
+
+  C10_HOST_DEVICE
+  operator_brackets_proxy& operator=(value_type const& val) {
+    *accessor = val;  // assignment writes through the stored accessor
+    return *this;
+  }
+
+private:
+  Accessor accessor;
+};
+
+}
+
+// references_holder is used as a surrogate for the
+// references type from std::iterator_traits in CompositeRandomAccessor.
+// It is assumed in CompositeRandomAccessor that
+// References = tuple<Types&...>,
+// Values = tuple<Types...> by default,
+// but they could be anything as long as References could be
+// cast to Values.
+// If you plan to use it with STL, for example, you will need to
+// define `swap` and `get` (aka std::get) methods.
+template  // NOTE(review): template parameter list appears to be missing (the body uses `Values` and `References`) -- restore before compiling
+class references_holder {
+public:
+  using values = Values;
+  using references = References;
+
+  C10_HOST_DEVICE
+  references_holder(references refs)
+    : refs{std::move(refs)}
+  {}
+
+  C10_HOST_DEVICE
+  operator references() {  // hands back a copy of the held references object
+    return refs;
+  }
+
+  C10_HOST_DEVICE
+  operator values() {  // relies on References being convertible to Values
+    return refs;
+  }
+
+  C10_HOST_DEVICE
+  references_holder& operator=(values vals) {
+    refs = vals;  // presumably assigns through the held references when References is a tuple of references
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  references& data() {  // direct access to the held references object
+    return refs;
+  }
+
+protected:
+  references refs;
+};
+
+// CompositeRandomAccessor is essentially a simplified version of
+// a random access iterator over two random access iterators.
+// TupleInfo should contain a variadic type `tuple`, and a method `tie`,
+// which constructs a tuple of references from a variadic list of arguments.
+template  // NOTE(review): template parameter lists/arguments appear to be missing throughout this class (angle-bracket text lost?) -- restore before compiling
+class CompositeRandomAccessor {
+  using self_type = CompositeRandomAccessor;
+
+  using key_accessor_value_type =
+    typename std::iterator_traits::value_type;
+  using value_accessor_value_type =
+    typename std::iterator_traits::value_type;
+  using key_accessor_reference_type =
+    typename std::iterator_traits::reference;
+  using value_accessor_reference_type =
+    typename std::iterator_traits::reference;
+
+  using composite_value_type = typename TupleInfo::template tuple<
+    key_accessor_value_type,
+    value_accessor_value_type>;
+  using composite_reference = typename TupleInfo::template tuple<
+    key_accessor_reference_type,
+    value_accessor_reference_type>;
+
+public:
+  using value_type = composite_value_type;
+  using reference = references_holder;
+  // Note that CompositeRandomAccessor does not hold key and values
+  // in a specific datastructure, which means that a pointer to a (key, value)
+  // is not defined. Hence we just use a pointer type of the KeyAccessor.
+  using pointer = typename std::iterator_traits::pointer;
+  using difference_type = typename std::iterator_traits::difference_type;
+  using iterator_category = std::random_access_iterator_tag;
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor() = default;
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
+    : keys(keys), values(values)
+  {}
+
+  // Pointer-like operations {
+  C10_HOST_DEVICE
+  reference operator*() const {
+    return TupleInfo::tie(*keys, *values);  // ties the current key and the current value together
+  }
+
+  // operator->() is supposed to return a pointer type.
+  // Since CompositeRandomAccessor does not hold pointers to pairs,
+  // we just return a pointer to a key.
+  C10_HOST_DEVICE
+  auto* operator->() const {
+    return keys.operator->();
+  }
+
+  C10_HOST_DEVICE
+  reference operator[](difference_type idx) {
+    return operator_brackets_proxy(  // proxy over a copy of this accessor advanced by idx
+      CompositeRandomAccessor(keys + idx, values + idx)
+    );
+  }
+  // }
+
+  // Prefix/postfix increment/decrement {
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator++() {
+    ++keys;
+    ++values;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator++(int) {
+    CompositeRandomAccessor copy(*this);  // postfix: return the state from before the increment
+    ++*this;
+    return copy;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator--() {
+    --keys;
+    --values;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator--(int) {
+    CompositeRandomAccessor copy(*this);  // postfix: return the state from before the decrement
+    --*this;
+    return copy;
+  }
+  // }
+
+  // Arithmetic operations {
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator+=(difference_type offset) {
+    keys += offset;
+    values += offset;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator+(difference_type offset) const {
+    return CompositeRandomAccessor(keys + offset, values + offset);
+  }
+
+  C10_HOST_DEVICE
+  friend CompositeRandomAccessor operator+(
+    difference_type offset,
+    const CompositeRandomAccessor& accessor
+  ) {
+    return accessor + offset;  // offset + accessor delegates to accessor + offset
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator-=(difference_type offset) {
+    keys -= offset;
+    values -= offset;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator-(difference_type offset) const {
+    return CompositeRandomAccessor(keys - offset, values - offset);
+  }
+
+  C10_HOST_DEVICE
+  difference_type operator-(const CompositeRandomAccessor& other) const {
+    return keys - other.keys;  // distance is measured on the key accessor only
+  }
+  // }
+
+  // Comparison operators {
+  C10_HOST_DEVICE
+  bool operator==(const CompositeRandomAccessor& other) const {
+    return keys == other.keys;  // every comparison below looks at keys only; values are not consulted
+  }
+
+  C10_HOST_DEVICE
+  bool operator!=(const CompositeRandomAccessor& other) const {
+    return keys != other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator<(const CompositeRandomAccessor& other) const {
+    return keys < other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator<=(const CompositeRandomAccessor& other) const {
+    return keys <= other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator>(const CompositeRandomAccessor& other) const {
+    return keys > other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator>=(const CompositeRandomAccessor& other) const {
+    return keys >= other.keys;
+  }
+  // }
+
+protected:
+  KeyAccessor keys;
+  ValueAccessor values;
+};
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ConvUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ConvUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..a5ced6932024060b664136205e8317ac4b930ef0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ConvUtils.h
@@ -0,0 +1,480 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at::native {
+
+using conv_depthwise2d_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array);
+DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub)
+using conv_depthwise3d_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array);
+DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub)
+using cudnn_convolution_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, bool, bool, bool, std::array);
+DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub)
+using mps_convolution_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, std::array);
+DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub)
+using cudnn_convolution_transpose_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array);
+DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub)
+using miopen_convolution_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, bool, bool, std::array);
+DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub)
+using miopen_convolution_transpose_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array);
+DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub)
+using miopen_depthwise_convolution_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, bool, bool, std::array);
+DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub)
+using mkldnn_convolution_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, std::array);
+DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub)
+using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional&,
+    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
+DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub)
+using mkldnn_convolution_transpose_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, std::array);
+DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub)
+using slow_conv_dilated2d_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array);
+DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub)
+using slow_conv_dilated3d_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array);
+DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub)
+using slow_conv_transpose2d_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array);
+DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub)
+using slow_conv_transpose3d_backward_fn = std::tuple(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array);
+DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub)
+
+namespace {
+  bool is_cudnnv8_heuristic_mode_b() {  // env var is read on the first call; the answer lives in a function-local static
+    static const bool mode_b_requested = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
+    return mode_b_requested;
+  }
+}  // anonymous namespace
+
+inline bool cudnnv8_enabled_check_debug() {  // true unless check_env reports TORCH_CUDNN_V8_API_DISABLED as true
+  static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
+  static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
+  static uint8_t cudnnv8_debugcount = 0;  // limits the warning below to 10 emissions
+  if (cudnnv8_debug && cudnnv8_debugcount < 10) {
+    TORCH_WARN("TORCH_CUDNN_V8_API_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE_B: ", is_cudnnv8_heuristic_mode_b());  // names match the variables actually read
+    cudnnv8_debugcount++;
+  }
+  return cudnnv8_flag;
+}
+
+inline bool cudnnv8_use_heur_mode_b() {  // public accessor for the file-local is_cudnnv8_heuristic_mode_b()
+  return is_cudnnv8_heuristic_mode_b();
+}
+
+// Keep in sync with py::enum_ in Module.cpp
+enum class ConvBackend {  // scoped enum; no explicit values, so enumerators number from 0 in the order listed
+  CudaDepthwise2d,
+  CudaDepthwise3d,
+  Cudnn,
+  CudnnTranspose,
+  Empty,
+  Miopen,
+  MiopenDepthwise,
+  MiopenTranspose,
+  Mkldnn,
+  MkldnnTranspose,
+  MkldnnEmpty,
+  NnpackSpatial,
+  Overrideable,
+  Slow2d,
+  Slow3d,
+  SlowDilated2d,
+  SlowDilated3d,
+  SlowTranspose2d,
+  SlowTranspose3d,
+  Winograd3x3Depthwise,
+  Xnnpack2d,
+  Mps,
+  MpsTranspose,
+};
+
+// Overload for selecting the convolution backend from the full set of convolution inputs.
+// This overload is exposed to python for testing, etc.
+TORCH_API ConvBackend select_conv_backend(
+    const Tensor& input, const Tensor& weight, const std::optional& bias_opt,
+    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
+    bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);
+
+TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
+    const Tensor& weight,
+    const ConvBackend backend);
+
+// ---------------------------------------------------------------------
+//
+// Math
+//
+// ---------------------------------------------------------------------
+
+constexpr int input_batch_size_dim = 0;  // also grad_input
+constexpr int input_channels_dim = 1;
+constexpr int output_batch_size_dim = 0;  // also grad_output
+constexpr int output_channels_dim = 1;
+constexpr int weight_output_channels_dim = 0;
+constexpr int weight_input_channels_dim = 1;
+
+// Often written as 2 + max_dim (extra dims for batch size and channels)
+constexpr int max_dim = 3;
+
+// ---------------------------------------------------------------------
+//
+// Checking
+//
+// ---------------------------------------------------------------------
+
+// Used on pad, stride and dilation: requires exactly `expected_size` values, none of them negative.
+static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
+{
+  TORCH_CHECK(args.size() <= expected_size,
+           "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
+           expected_size, " (while checking arguments for ", c, ")");
+  TORCH_CHECK(args.size() >= expected_size,
+           "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
+           expected_size, " (while checking arguments for ", c, ")");
+
+  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int64_t x){return x < 0;});  // int64_t: elements are tested without narrowing
+  if (num_negative_values > 0){  // only reachable with a non-empty args, so end() - 1 and back() below are valid
+    std::stringstream ss;
+    ss << arg_name << " should be greater than or equal to zero but got (";  // zero itself passes the x < 0 test
+    std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", "));  // all but the last value, each followed by ", "
+    ss << args.back() <<  ")" << " (while checking arguments for " << c << ')';
+    TORCH_CHECK(false, ss.str());
+  }
+}
+
+
+// NOTE [ Convolution checks ]
+//
+// NB: For many call sites, it is not strictly necessary to check all of
+// these relationships (for example, for forward convolution, we compute
+// the size of output ourselves, so we don't actually need to check
+// output).  However, writing a single function that does everything
+// means we get to reuse it for both forwards and all backwards
+// variants, even when the set of "real" inputs varies.  The magic of
+// relational computing!
+//
+// (There is one downside, which is that it is slightly harder to write
+// error messages which are able to distinguish between real inputs
+// (which the user can change) and computed inputs (which the user can
+// only indirectly affect).  It would be an interesting exercise to
+// come up with a general framework to handle such situations.)
+inline void convolution_shape_check(
+    CheckedFrom c,
+    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
+{
+  check_args(c, padding, input->dim() - 2, "padding");  // exactly input->dim() - 2 padding values, none negative
+  check_args(c, stride, padding.size(), "stride");      // stride and dilation must be as long as padding
+  check_args(c, dilation, padding.size(), "dilation");
+
+  // Input
+  checkDimRange(c, input, 3, 6 /* exclusive */);  // upper bound is exclusive, i.e. 3, 4 or 5 dimensions
+  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);  // presumably: input channel extent must equal weight->size(1) * groups
+
+  // Weight
+  checkSameDim(c, input, weight);
+
+  // TODO: check that output->size() matches output_sizes
+  // TODO: check that weight matches output->sizes()
+  checkSameDim(c, input, output);
+}
+
+// NB: conv_output_size and conv_input_size are not bijections,
+// as conv_output_size loses information; this is why conv_input_size
+// takes an extra output_padding argument to resolve the ambiguity.
+
+template 
+inline std::vector _conv_output_size(
+    ArrayRef input_size, ArrayRef weight_size,
+    ArrayRef padding, ArrayRef stride, ArrayRef dilation = ArrayRef()
+) {
+  // ASSERT(input_size.size() > 2)
+  // ASSERT(input_size.size() == weight_size.size())
+  bool has_dilation = !dilation.empty();  // an empty dilation array is treated as dilation 1 in every dimension
+  auto dim = input_size.size();
+  std::vector output_size(dim);
+  output_size[0] = input_size[input_batch_size_dim];  // batch extent carries over from the input
+  output_size[1] = weight_size[weight_output_channels_dim];  // channel extent comes from the weight
+  for (const auto d : c10::irange(2, dim)) {  // remaining dims; padding/stride/dilation are indexed by d - 2
+    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
+    auto kernel = dilation_ * (weight_size[d] - 1) + 1;  // kernel extent once dilation is applied
+    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;  // padding is counted twice (both sides)
+  }
+  return output_size;
+}
+
+inline std::vector conv_output_size(  // IntArrayRef overload
+    IntArrayRef input_size, IntArrayRef weight_size,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()  // dilation may be omitted (empty)
+) {
+  return _conv_output_size(input_size, weight_size, padding, stride, dilation);  // forwards all arguments unchanged
+}
+
+inline std::vector conv_output_size(  // SymIntArrayRef overload
+    SymIntArrayRef input_size, SymIntArrayRef weight_size,
+    SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()  // dilation may be omitted (empty)
+) {
+  return _conv_output_size(input_size, weight_size, padding, stride, dilation);  // forwards all arguments unchanged
+}
+
+template 
+std::vector _conv_input_size(
+    ArrayRef output_size, ArrayRef weight_size,
+    ArrayRef padding, ArrayRef output_padding, ArrayRef stride, ArrayRef dilation, T groups
+) {
+  // ASSERT(output_size.size() > 2)
+  // ASSERT(output_size.size() == weight_size.size())
+  auto dim = output_size.size();
+  std::vector input_size(dim);
+  input_size[0] = output_size[output_batch_size_dim];  // batch extent carries over from the output
+  input_size[1] = weight_size[weight_input_channels_dim] * groups;  // per-group channel count times number of groups
+  for (const auto d : c10::irange(2, dim)) {  // remaining dims; per-dimension arguments are indexed by d - 2
+    auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;  // dilation is indexed unconditionally here (no empty-array shortcut)
+    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
+                     kernel + output_padding[d - 2];  // output_padding is added on top of the computed extent
+  }
+  return input_size;
+}
+
+inline std::vector conv_input_size(  // SymIntArrayRef overload
+    SymIntArrayRef output_size, SymIntArrayRef weight_size,
+    SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
+) {
+  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));  // groups is taken by value and moved on
+}
+
+inline std::vector conv_input_size(  // IntArrayRef overload
+    IntArrayRef output_size, IntArrayRef weight_size,
+    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);  // forwards all arguments unchanged
+}
+
+template 
+std::vector _conv_weight_size(
+    ArrayRef input_size, ArrayRef output_size,
+    ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups  // NOTE: stride, dilation and groups use fixed integer types here
+) {
+  auto dim = input_size.size();
+  std::vector weight_size(dim);
+  weight_size[0] = output_size[1];  // literal index 1 on both tensors (the named *_channels_dim constants are not used here)
+  weight_size[1] = input_size[1] / groups;  // integer division by the number of groups
+  for (const auto d : c10::irange(2, dim)) {  // remaining dims; per-dimension arguments are indexed by d - 2
+    auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
+               + padding[d - 2] * 2 - output_padding[d - 2];  // extent covered by the dilated kernel
+    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;  // undo the dilation to get the stored kernel extent
+  }
+  return weight_size;
+}
+
+inline std::vector conv_weight_size(  // SymIntArrayRef overload; stride and dilation remain IntArrayRef
+    SymIntArrayRef input_size, SymIntArrayRef output_size,
+    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);  // forwards all arguments unchanged
+}
+
+inline std::vector conv_weight_size(  // IntArrayRef overload
+    IntArrayRef input_size, IntArrayRef output_size,
+    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);  // forwards all arguments unchanged
+}
+
+inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {  // reshapes bias to `dim` dimensions
+  std::vector shape(dim, 1);  // every extent is 1 ...
+  shape[1] = -1;              // ... except index 1, which is set to -1; writes index 1, so assumes dim >= 2 -- TODO confirm callers
+  return bias.reshape(shape);
+}
+
+inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
+  // Float64 operands, or a build without cuDNN, always get the contiguous layout.
+  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
+      input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return at::MemoryFormat::Contiguous;
+  }
+  const long version = at::detail::getCUDAHooks().versionCuDNN();
+  const auto input_fmt = input.suggest_memory_format();
+  const auto weight_fmt = weight.suggest_memory_format();
+  const auto weight_rank = weight.ndimension();
+
+  // True when at least one of the two tensors suggests `fmt`.
+  const auto either_suggests = [&](at::MemoryFormat fmt) {
+    return input_fmt == fmt || weight_fmt == fmt;
+  };
+
+  // 2d case: 4-dimensional weight and a cuDNN version of at least 7603.
+  if (version >= 7603 && weight_rank == 4 &&
+      either_suggests(at::MemoryFormat::ChannelsLast)) {
+    return at::MemoryFormat::ChannelsLast;
+  }
+
+  // 3d case: 5-dimensional weight and a cuDNN version of at least 8005.
+  if (version >= 8005 && weight_rank == 5 &&
+      either_suggests(at::MemoryFormat::ChannelsLast3d)) {
+    return at::MemoryFormat::ChannelsLast3d;
+  }
+  return at::MemoryFormat::Contiguous;
+}
+
+// controls whether emptyCache will be called following cudnn conv benchmarking
+TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
+TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
+
+
+inline at::MemoryFormat miopen_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
+  // disable NHWC for float64 input.
+  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
+      input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return at::MemoryFormat::Contiguous;
+  }
+
+  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+  // See https://github.com/pytorch/pytorch/issues/64427.
+  // non static variable is used to be able to change environment variable in runtime for testing
+  // enabled by default for ROCm >= 7.0.0 with miopen 3.5
+  int miopen_version = at::detail::getCUDAHooks().versionMIOpen();  // compiledWithMIOpen() is already known to be true past the early return
+  bool is_miopen_3_5 = miopen_version >= 30500;  // ROCm 7.0
+  bool suggest_nhwc = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(is_miopen_3_5);  // the env var, when reported, overrides the version default
+
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+  auto weight_ndim = weight.ndimension();
+
+  bool can_use_miopen_channels_last_2d = suggest_nhwc && (weight_ndim == 4) && (  // 2d: 4-dimensional weight
+    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
+    (weight_memory_format == at::MemoryFormat::ChannelsLast)
+  );
+  if (can_use_miopen_channels_last_2d) {
+    return at::MemoryFormat::ChannelsLast;
+  }
+
+  bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (  // 3d: 5-dimensional weight
+    (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
+    (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
+  );
+  if (can_use_miopen_channels_last_3d) {
+    return at::MemoryFormat::ChannelsLast3d;
+  }
+
+  return at::MemoryFormat::Contiguous;
+}
+
+// deprecated, but to remove would be BC-breaking
+inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+  return miopen_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous;  // true for either channels-last variant
+}
+
+inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+  // Channels-last is ruled out for float64 data ...
+  const bool has_double_operand =
+      input.scalar_type() == at::kDouble || weight.scalar_type() == at::kDouble;
+  if (has_double_operand) {
+    return false;
+  }
+
+  // ... and for tensors that report themselves as mkldnn tensors.
+  const bool has_mkldnn_operand = input.is_mkldnn() || weight.is_mkldnn();
+  if (has_mkldnn_operand) {
+    return false;
+  }
+
+  const auto input_fmt = input.suggest_memory_format();
+  const auto weight_fmt = weight.suggest_memory_format();
+
+  // True when at least one of the two tensors suggests `fmt`.
+  const auto either_suggests = [&](at::MemoryFormat fmt) {
+    return input_fmt == fmt || weight_fmt == fmt;
+  };
+
+  // Both the 2d and the 3d channels-last layouts qualify.
+  return either_suggests(at::MemoryFormat::ChannelsLast) ||
+         either_suggests(at::MemoryFormat::ChannelsLast3d);
+}
+
+inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+  const auto input_fmt = input.suggest_memory_format();
+  const auto weight_fmt = weight.suggest_memory_format();
+
+  // Only the 2d channels-last layout counts here, and only for a CPU input.
+  if (!input.device().is_cpu()) {
+    return false;
+  }
+  return input_fmt == at::MemoryFormat::ChannelsLast ||
+         weight_fmt == at::MemoryFormat::ChannelsLast;
+}
+
+inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+  // Layout is only inspected when both tensors are xpu tensors.
+  const bool both_on_xpu = input.is_xpu() && weight.is_xpu();
+  if (!both_on_xpu) {
+    return false;
+  }
+  // An undefined or sparse input falls back to channels_first.
+  const bool input_unusable = !input.defined() || input.is_sparse();
+  if (input_unusable) {
+    return false;
+  }
+  const auto suggests_channels_last = [](const at::Tensor& tensor) {
+    const auto layout = tensor.suggest_memory_format();
+    return layout == at::MemoryFormat::ChannelsLast || layout == at::MemoryFormat::ChannelsLast3d;
+  };
+  return suggests_channels_last(input) || suggests_channels_last(weight);
+}
+
+inline bool mps_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+  // Layout is only inspected when both tensors are mps tensors.
+  const bool both_on_mps = input.is_mps() && weight.is_mps();
+  if (!both_on_mps) {
+    return false;
+  }
+  // An undefined or sparse input falls back to channels_first.
+  const bool input_unusable = !input.defined() || input.is_sparse();
+  if (input_unusable) {
+    return false;
+  }
+  const auto suggests_channels_last = [](const at::Tensor& tensor) {
+    const auto layout = tensor.suggest_memory_format();
+    return layout == at::MemoryFormat::ChannelsLast || layout == at::MemoryFormat::ChannelsLast3d;
+  };
+  return suggests_channels_last(input) || suggests_channels_last(weight);
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Cross.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Cross.h
new file mode 100644
index 0000000000000000000000000000000000000000..2100156a2dd23e0da0df84bcb23bd21b173a2bf5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Cross.h
@@ -0,0 +1,19 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+
+namespace at {
+class Tensor;  // forward declaration; Tensor is only used by reference in the signature below
+
+namespace native {
+
+using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);  // function-pointer type of a cross kernel
+
+DECLARE_DISPATCH(cross_fn, cross_stub)  // declares the `cross_stub` dispatch object for kernels of type cross_fn
+
+}} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DispatchStub.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DispatchStub.h
new file mode 100644
index 0000000000000000000000000000000000000000..a57093238a8beea1b2d285ff17db067865eb6992
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DispatchStub.h
@@ -0,0 +1,500 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+// Implements instruction set specific function dispatch.
+//
+// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
+// compiled multiple times with different compiler flags (e.g. -mavx2). A
+// DispatchStub contains a table of function pointers for a kernel. At runtime,
+// the fastest available kernel is chosen based on the features reported by
+// cpuinfo.
+//
+// Example:
+//
+// In native/MyKernel.h:
+//   using fn_type = void(*)(const Tensor& x);
+//   DECLARE_DISPATCH(fn_type, stub)
+//
+// In native/MyKernel.cpp
+//   DEFINE_DISPATCH(stub);
+//
+// In native/cpu/MyKernel.cpp:
+//   namespace {
+//     // use anonymous namespace so that different cpu versions won't conflict
+//     void kernel(const Tensor& x) { ... }
+//   }
+//   REGISTER_DISPATCH(stub, &kernel);
+//
+// To call:
+//   stub(kCPU, tensor);
+//
+// TODO: CPU instruction set selection should be folded into whatever
+// the main dispatch mechanism is.
+//
+// Supported device types for registration:
+//   - CPU: Central Processing Unit
+//   - CUDA: NVIDIA GPUs
+//   - HIP: AMD GPUs
+//   - MPS: Apple Silicon GPUs (Metal Performance Shaders)
+//   - MTIA: Meta Training and Inference Devices
+//   - XPU: Intel GPUs
+//   - HPU: Reserved for HPU (Intel Gaudi) device types
+//   - PrivateUse1: Reserved for private/custom device types
+//
+// If you want to update the list of supported devices, add a new dispatch_ptr
+// member in DispatchStubImpl.h and update the get_call_ptr switch.
+// You will also need to update the inlined list in `is_device_supported`.
+//
+//
+// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
+C10_CLANG_DIAGNOSTIC_PUSH()
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
+
+namespace at::native {
+
+enum class CPUCapability {
+  DEFAULT = 0,  // present in every build, whichever branch below is taken
+#if defined(HAVE_VSX_CPU_DEFINITION)
+  VSX = 1,
+#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
+  ZVECTOR = 1,
+#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
+  SVE256 = 1,  // requires both the SVE256 and the ARM BF16 definitions
+#else
+  AVX2 = 1,    // fallback branch when none of the macros above is defined
+  AVX512 = 2,
+#endif
+  NUM_OPTIONS  // one more than the last enumerator above, i.e. the number of capabilities in this build
+};
+
+// Failure reasons that the try_* lookups of DispatchStubImpl return instead of raising
+enum class ErrorType {
+  MissingDeviceKernel,  // presumably: the device type is handled but no kernel is registered for it -- TODO confirm
+  DeviceNotSupported    // presumably: the device type is not handled by the stub at all -- TODO confirm
+};
+
+// Alias for the return type using std::variant
+using DispatchResult = std::variant;
+
+CPUCapability get_cpu_capability();
+
+template 
+struct DispatchStub;
+
+/**
+ * The sole purpose of this class is to outline methods that don't need to be
+ * specialized or otherwise inlined and duplicated (by the compiler due to
+ * template expansion), since it causes size bloat if there are a significant
+ * number of specialization of the DispatchStub<> class.
+ */
+struct TORCH_API DispatchStubImpl {
+
+  // The DispatchStubImpl::try_get_call_ptr() method is used to get the call
+  // pointer for a given device type. If the call pointer is not found,
+  // DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
+  // The main difference between try_get_call_ptr() and get_call_ptr() is that
+  // try_get_call_ptr() will return the ErrorType and not raise an exception.
+  DispatchResult try_get_call_ptr(
+    c10::DeviceType device_type
+    , void *DEFAULT
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , void *AVX512
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+      , void *AVX2
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+      , void *VSX
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+      , void *ZVECTOR
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+      , void *SVE256
+#endif
+  );
+
+  // Non-throwing counterpart of choose_cpu_impl(): like try_get_call_ptr(), it
+  // reports failure by returning an ErrorType instead of raising an exception.
+  DispatchResult try_choose_cpu_impl(
+    void *DEFAULT
+#ifdef HAVE_AVX512_CPU_DEFINITION
+    , void *AVX512
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+    , void *AVX2
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+    , void *VSX
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+    , void *ZVECTOR
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+    , void *SVE256
+#endif
+  );
+
+
+  void* get_call_ptr(
+    c10::DeviceType device_type
+    , void *DEFAULT
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , void *AVX512
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+      , void *AVX2
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+      , void *VSX
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+      , void *ZVECTOR
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+      , void *SVE256
+#endif
+  );
+
+  /**
+   * When DispatchStubImpl::get_call_ptr() finds no CPU kernel in
+   * cpu_dispatch_ptr, DispatchStubImpl::choose_cpu_impl() selects one,
+   * trying the candidates in decreasing order of preference.
+   */
+  void* choose_cpu_impl(
+    void *DEFAULT
+#ifdef HAVE_AVX512_CPU_DEFINITION
+    , void *AVX512
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+    , void *AVX2
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+    , void *VSX
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+    , void *ZVECTOR
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+    , void *SVE256
+#endif
+  );
+
+  // Fixing dispatch error in Windows debug builds.
+  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
+  #if defined(_MSC_VER) && defined(_DEBUG)
+    std::atomic cpu_dispatch_ptr;
+    void* cuda_dispatch_ptr;
+    void* hip_dispatch_ptr;
+    void* mps_dispatch_ptr;
+    void* mtia_dispatch_ptr;
+  #if defined(USE_XPU)
+    void* xpu_dispatch_ptr;
+  #endif
+    void* hpu_dispatch_ptr;
+    void* privateuse1_dispatch_ptr;
+  #else
+    std::atomic cpu_dispatch_ptr{nullptr};
+    void* cuda_dispatch_ptr = nullptr;
+    void* hip_dispatch_ptr = nullptr;
+    void* mps_dispatch_ptr = nullptr;
+    void* mtia_dispatch_ptr = nullptr;
+  #if defined(USE_XPU)
+    void* xpu_dispatch_ptr = nullptr;
+  #endif
+    void* hpu_dispatch_ptr = nullptr;
+    void* privateuse1_dispatch_ptr = nullptr;
+  #endif
+};
+
+template 
+struct DispatchStub {
+  using FnPtr = rT (*) (Args...);  // signature shared by every kernel registered on this stub
+
+  DispatchStub() = default;
+  DispatchStub(const DispatchStub&) = delete;  // non-copyable
+  DispatchStub& operator=(const DispatchStub&) = delete;
+
+private:
+  FnPtr get_call_ptr(const c10::DeviceType device_type) {  // resolves the kernel for device_type through DispatchStubImpl
+    return reinterpret_cast(
+      impl.get_call_ptr(device_type
+      , reinterpret_cast(DEFAULT)  // the per-arch candidates are passed to the untyped impl
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , reinterpret_cast(AVX512)
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+      , reinterpret_cast(AVX2)
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+      , reinterpret_cast(VSX)
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+      , reinterpret_cast(ZVECTOR)
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+      , reinterpret_cast(SVE256)
+#endif
+      )
+    );
+  }
+
+public:
+  template 
+  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {  // call syntax: stub(device_type, args...)
+    FnPtr call_ptr = get_call_ptr(device_type);
+    return (*call_ptr)(std::forward(args)...);  // invokes the resolved kernel with the forwarded arguments
+  }
+
+  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {  // each setter below stores fn_ptr in the matching DispatchStubImpl slot
+    impl.cuda_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+
+  #if defined(USE_XPU)
+  void set_xpu_dispatch_ptr(FnPtr fn_ptr){  // only available when USE_XPU is defined
+    impl.xpu_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+  #endif
+
+  void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
+    impl.hpu_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+
+  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
+    impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+
+  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
+    impl.mps_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+
+    void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
+    impl.mtia_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+
+  void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
+    impl.privateuse1_dispatch_ptr = reinterpret_cast(fn_ptr);
+  }
+
+  // Returns true if the dispatcher has a kernel registered for this device
+  // type.
+  bool is_device_supported(const c10::DeviceType device_type) {
+    auto result = impl.try_get_call_ptr(device_type  // non-raising lookup; same candidate list as get_call_ptr
+      , reinterpret_cast(DEFAULT)
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , reinterpret_cast(AVX512)
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+      , reinterpret_cast(AVX2)
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+      , reinterpret_cast(VSX)
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+      , reinterpret_cast(ZVECTOR)
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+      , reinterpret_cast(SVE256)
+#endif
+      );
+    if (std::holds_alternative(result)){  // an error result means there is no kernel to call for this device type
+      return false;
+    }
+    return true;
+  }
+
+  static TORCH_API FnPtr DEFAULT;  // per-architecture kernel slots; REGISTER_ARCH_DISPATCH supplies their definitions
+#ifdef HAVE_AVX512_CPU_DEFINITION
+  static TORCH_API FnPtr AVX512;
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+  static TORCH_API FnPtr AVX2;
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+  static TORCH_API FnPtr VSX;
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+  static TORCH_API FnPtr ZVECTOR;
+#endif
+#ifdef HAVE_SVE256_CPU_DEFINITION
+  static TORCH_API FnPtr SVE256;
+#endif
+private:
+  DispatchStubImpl impl;  // untyped state: the per-device pointers written by the setters above
+};
+
+namespace {  // registration helpers: constructing one stores `value` on `stub` for the named device
+template 
+struct RegisterCUDADispatch {
+  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_cuda_dispatch_ptr(value);
+  }
+};
+
+template 
+struct RegisterXPUDispatch {
+  RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
+    stub.set_xpu_dispatch_ptr(value);  // set_xpu_dispatch_ptr is only declared when USE_XPU is defined
+  }
+};
+
+template 
+struct RegisterHPUDispatch {
+  RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
+    stub.set_hpu_dispatch_ptr(value);
+  }
+};
+
+template 
+struct RegisterMPSDispatch {
+  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_mps_dispatch_ptr(value);
+  }
+};
+
+template 
+struct RegisterHIPDispatch {
+  RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    // TODO: make this point at hip_dispatch_ptr
+    stub.set_cuda_dispatch_ptr(value);  // NOTE: HIP registrations currently land in the CUDA slot, not hip_dispatch_ptr
+  }
+};
+
+template 
+struct RegisterMTIADispatch {
+  RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_mtia_dispatch_ptr(value);
+  }
+};
+
+template 
+struct RegisterPRIVATEUSE1Dispatch {
+  RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_privateuse1_dispatch_ptr(value);
+  }
+};
+
+} // anonymous namespace
+// Compiler will complain if you put things like std::tuple in
+// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
+// adding parentheses and using helper struct to get rid of the parentheses, do
+// not work with MSVC. So do a `using`-declaration if you need to pass in such
+// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
+#define DECLARE_DISPATCH(fn, name)                                                         \
+  struct name##_DECLARE_DISPATCH_type : DispatchStub {   \
+    name##_DECLARE_DISPATCH_type() = default;                                              \
+    name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete;            \
+    name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
+    name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete;                 \
+    name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete;      \
+    ~name##_DECLARE_DISPATCH_type() = default;                                             \
+  };                                                                                       \
+  extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
+
+#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
+
+#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
+  template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub::arch = fn;
+
+#ifdef HAVE_AVX512_CPU_DEFINITION
+#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
+#else
+#define REGISTER_AVX512_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_AVX2_CPU_DEFINITION
+#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
+#else
+#define REGISTER_AVX2_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_VSX_CPU_DEFINITION
+#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
+#else
+#define REGISTER_VSX_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
+#else
+#define REGISTER_ZVECTOR_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_SVE256_CPU_DEFINITION
+#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
+#else
+#define REGISTER_SVE256_DISPATCH(name, fn)
+#endif
+
+// Macro to register the same kernel for all CPU arch types. This is useful
+// if a kernel does not benefit from being recompiled across different arch types.
+#define REGISTER_ALL_CPU_DISPATCH(name, fn)                                    \
+  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn)                                    \
+  REGISTER_AVX512_DISPATCH(name, fn)                                           \
+  REGISTER_AVX2_DISPATCH(name, fn)                                             \
+  REGISTER_VSX_DISPATCH(name, fn)                                              \
+  REGISTER_ZVECTOR_DISPATCH(name, fn)                                          \
+  REGISTER_SVE256_DISPATCH(name, fn)
+
+#define REGISTER_NO_CPU_DISPATCH(name)                                         \
+  REGISTER_ALL_CPU_DISPATCH(name, nullptr)
+
+#define REGISTER_CUDA_DISPATCH(name, fn) \
+  static RegisterCUDADispatch name ## __register(name, fn);
+
+#define REGISTER_XPU_DISPATCH(name, fn) \
+  static RegisterXPUDispatch name ## __register(name, fn);
+
+#define REGISTER_HPU_DISPATCH(name, fn) \
+  static RegisterHPUDispatch name ## __register(name, fn);
+
+#define REGISTER_HIP_DISPATCH(name, fn) \
+  static RegisterHIPDispatch name ## __register(name, fn);
+
+#define REGISTER_MPS_DISPATCH(name, fn) \
+  static RegisterMPSDispatch name ## __register(name, fn);
+
+#define REGISTER_MTIA_DISPATCH(name, fn) \
+  static RegisterMTIADispatch name ## __register(name, fn);
+
+#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
+  static RegisterPRIVATEUSE1Dispatch name ## __register(name, fn);
+
+// NB: This macro must be used in an actual 'cu' file; if you try using
+// it from a 'cpp' file it will not work!
+#if defined(__CUDACC__)
+#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
+#elif defined(__HIPCC__)
+// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
+// is HIP in the PyTorch HIPify build.
+#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
+// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
+#elif defined(__OBJC__) && defined(USE_MPS)
+// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
+#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
+#elif defined(CPU_CAPABILITY)
+// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
+// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
+// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others.
+#ifdef CPU_CAPABILITY_AVX512
+#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
+#else
+#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
+#endif
+#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
+#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
+#endif
+} // namespace at::native
+
+C10_CLANG_DIAGNOSTIC_POP()
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Distance.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Distance.h
new file mode 100644
index 0000000000000000000000000000000000000000..774434385595c7e3e3a35e1b956062f6156abb01
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Distance.h
@@ -0,0 +1,25 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
+using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
+using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+
+DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub)
+DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub)
+DECLARE_DISPATCH(cdist_fn, cdist_stub)
+DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub)
+
+}} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Distributions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Distributions.h
new file mode 100644
index 0000000000000000000000000000000000000000..81dfbd07129a8efa07986e99e68665ef67c18961
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Distributions.h
@@ -0,0 +1,524 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+// ROCM hcc doesn't work well with using std:: in kernel functions
+#if defined(__CUDA_ARCH__)
+#include 
+#define compat_exp c10::cuda::compat::exp
+#define compat_ceil c10::cuda::compat::ceil
+#define compat_floor c10::cuda::compat::floor
+#define compat_log c10::cuda::compat::log
+#define compat_pow c10::cuda::compat::pow
+#define compat_sqrt c10::cuda::compat::sqrt
+#define compat_tan c10::cuda::compat::tan
+#define compat_abs c10::cuda::compat::abs
+#define compat_log1p c10::cuda::compat::log1p
+#elif defined(__HIPCC__)
+#include 
+#define compat_exp c10::hip::compat::exp
+#define compat_ceil c10::hip::compat::ceil
+#define compat_floor c10::hip::compat::floor
+#define compat_log c10::hip::compat::log
+#define compat_pow c10::hip::compat::pow
+#define compat_sqrt c10::hip::compat::sqrt
+#define compat_tan c10::hip::compat::tan
+#define compat_abs c10::hip::compat::abs
+#define compat_log1p c10::hip::compat::log1p
+#else
+#define compat_exp std::exp
+#define compat_ceil std::ceil
+#define compat_floor std::floor
+#define compat_log std::log
+#define compat_pow std::pow
+#define compat_sqrt std::sqrt
+#define compat_tan std::tan
+#define compat_abs std::abs
+#define compat_log1p std::log1p
+#endif
+
+namespace {
+
+#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
+// we cannot use std::isnan directly due to some incompatibility of
+// gcc constexpr'ing and nvcc
+using std::isnan;
+#endif
+
+// Here sampler_t should be the function type scalar_t(void). On GPU,
+// "sampler" is a device function, but since ROCm does not have an
+// equivalent to nvstd::function, we use a template type parameter to
+// capture it.
+template
+struct BaseSampler {
+  sampler_t sampler;
+  C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
+  C10_DEVICE scalar_t sample() {
+    return sampler();
+  }
+};
+
+// The function `sample_gamma` below
+// is adapted from NumPy's distributions.c implementation.
+// It is MIT licensed, so here is the copyright:
+
+/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+*/
+
+template
+C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler& standard_uniform, BaseSampler& standard_normal) {
+  accscalar_t scale = 1.0f;
+
+  // Boost alpha for higher acceptance probability.
+  if (alpha < 1.0f) {
+    if (alpha == 0.f) return 0.f;
+    scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
+    alpha += 1.0f;
+  }
+
+  // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
+  // doi:10.1145/358407.358414
+  const accscalar_t d = alpha - 1.0f / 3.0f;
+  const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
+  for (;;) {
+    accscalar_t x, y;
+    do {
+      x = standard_normal.sample();
+      y = 1.0f + c * x;
+    } while (y <= 0);
+    const accscalar_t v = y * y * y;
+    const accscalar_t u = 1 - standard_uniform.sample();
+    const accscalar_t xx = x * x;
+    if (u < 1.0f - 0.0331f * xx * xx)
+      return static_cast(scale * d * v);
+    if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
+      return static_cast(scale * d * v);
+  }
+}
+
+/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
+ * from TensorFlow's random_binomial_op.cc implementation. That code is under
+ * copyright: 2019 The TensorFlow Authors.
+ *
+ * It was released under the Apache License, Version 2.0 (the "License"), available at:
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+template
+C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
+  constexpr static scalar_t kTailValues[] = {
+    0.0810614667953272,
+    0.0413406959554092,
+    0.0276779256849983,
+    0.02079067210376509,
+    0.0166446911898211,
+    0.0138761288230707,
+    0.0118967099458917,
+    0.0104112652619720,
+    0.00925546218271273,
+    0.00833056343336287
+  };
+  if (k < std::size(kTailValues)) {
+    return kTailValues[static_cast(k)];
+  }
+  scalar_t kp1sq = (k + 1) * (k + 1);
+  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
+}
+
+
+template
+C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler& standard_uniform) {
+  accscalar_t U;
+  accscalar_t geom_sum = 0;
+  scalar_t num_geom = 0;
+
+  accscalar_t logprob = compat_log1p(-prob);
+
+  while (true) {
+    U = standard_uniform.sample();
+    accscalar_t geom = compat_ceil(compat_log(U) / logprob);
+    geom_sum += geom;
+    if (geom_sum > count) {
+      break;
+    }
+    num_geom = num_geom + 1;
+  }
+  return num_geom;
+}
+
+template
+C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler& standard_uniform) {
+  scalar_t k;
+  accscalar_t U, V, us;
+
+  // This is spq in the paper: sqrt(count * prob * (1 - prob)).
+  const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
+
+  // Other coefficients for Transformed Rejection sampling.
+  const accscalar_t b = 1.15 + 2.53 * stddev;
+  const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
+  const accscalar_t c = count * prob + 0.5;
+  const accscalar_t v_r = 0.92 - 4.2 / b;
+  const accscalar_t r = prob / (1 - prob);
+
+  const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
+  const accscalar_t m = compat_floor((count + 1) * prob);
+
+  while (true) {
+    U = standard_uniform.sample() - 0.5;
+    V = standard_uniform.sample();
+
+    us = 0.5 - compat_abs(U);
+    k = static_cast(compat_floor((2 * a / us + b) * U + c));
+
+    // Reject nonsensical answers (k outside [0, count]) and draw a new pair.
+    if (k < 0 || k > count) {
+      continue;
+    }
+    // Region for which the box is tight, and we can return our calculated value.
+    // This should happen 0.86 * v_r times. In the limit as n * p is large,
+    // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
+    if (us >= 0.07 && V <= v_r) {
+      return k;
+    }
+
+    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
+    // For all (u, v) pairs outside of the bounding box, this calculates the
+    // transformed-reject ratio.
+    V = compat_log(V * alpha / (a / (us * us) + b));
+    accscalar_t upperbound =
+        ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
+         (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
+         (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
+         stirling_approx_tail(m) + stirling_approx_tail(count - m) -
+         stirling_approx_tail(k) - stirling_approx_tail(count - k));
+
+    if (V <= upperbound) {
+      return k;
+    }
+  }
+}
+
+template
+C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler& standard_uniform) {
+  if (count <= 0.0 || prob <= 0.0) {
+    return 0;
+  } else if (prob >= 1.0) {
+    return count;
+  } else if (prob <= 0.5) {
+    if (count * prob >= 10.0) {
+      // count * prob >= 10: sample with btrs (transformed rejection)
+      return btrs(count, prob, standard_uniform);
+    } else {
+      // otherwise: sample with binomial inversion
+      return binomial_inversion(count, prob, standard_uniform);
+    }
+  } else if (prob > 0.5) {
+    scalar_t qprob = 1.0 - prob;
+    if (count * qprob >= 10.0) {
+      // count * qprob >= 10: btrs on qprob = 1 - prob, subtracted from count
+      return count - btrs(count, qprob, standard_uniform);
+    } else {
+      // otherwise: count - binomial inversion on qprob
+      return count - binomial_inversion(count, qprob, standard_uniform);
+    }
+  } else {
+    // every comparison on prob above was false, so prob is NaN
+    return static_cast(NAN);
+  }
+}
+
+/*
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
+ */
+template
+C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
+  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
+  if (x == 0) {
+    return INFINITY;
+  }
+  accscalar_t additional_summand = 0;
+  int x_is_integer = x == compat_floor(x);
+  if (x < 0) {
+    if (x_is_integer) {
+      return INFINITY;
+    }
+    // it is more standard to write this as recursion, but
+    // nvcc does not like that
+    additional_summand = -c10::pi /
+        compat_tan(c10::pi * x);
+    x = 1 - x;
+  }
+
+  // Push x to be >= 10
+  accscalar_t result = 0;
+  while (x < 10) {
+    result -= 1 / x;
+    x += 1;
+  }
+  if (x == 10) {
+    return result + PSI_10 + additional_summand;
+  }
+
+  // Compute asymptotic digamma
+  static const accscalar_t A[] = {
+     8.33333333333333333333E-2,
+    -2.10927960927960927961E-2,
+     7.57575757575757575758E-3,
+    -4.16666666666666666667E-3,
+     3.96825396825396825397E-3,
+    -8.33333333333333333333E-3,
+     8.33333333333333333333E-2,
+  };
+
+  accscalar_t y = 0;
+  if (x < 1.0e17f) {
+    accscalar_t z = 1.0 / (x * x);
+    y = z * polevl(z, A, 6);
+  }
+  return static_cast(
+      result + compat_log(x) - (0.5f / x) - y + additional_summand);
+}
+
+// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
+// for random number x drawn from a standard Gamma distribution Gamma(alpha).
+template 
+C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
+  // Use a Taylor series expansion for small x.
+  accscalar_t x = static_cast(x_);
+  accscalar_t alpha = static_cast(alpha_);
+  if (x < 0.8f) {
+    accscalar_t numer = 1;
+    accscalar_t denom = alpha;
+    auto series1 = numer / denom;
+    auto series2 = numer / (denom * denom);
+    for (int i = 1; i <= 5; ++i) {
+      numer *= -x / static_cast(i);
+      denom += 1;
+      series1 += numer / denom;
+      series2 += numer / (denom * denom);
+    }
+    const auto pow_x_alpha = compat_pow(x, alpha);
+    const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
+    const auto gamma_cdf = pow_x_alpha * series1;
+    const auto gamma_cdf_alpha =
+        (compat_log(x) - digamma_one(alpha)) *
+            gamma_cdf -
+        pow_x_alpha * series2;
+    const auto result = -gamma_cdf_alpha / gamma_pdf;
+    return isnan(result) ? static_cast( 0.f ) : static_cast(result);
+  }
+
+  // Use a Rice saddle point expansion for large alpha.
+  if (alpha > 8.0f) {
+    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
+      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
+      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
+          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
+      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
+      return static_cast(numer_1 * numer_2 / denom);
+    }
+    const auto denom = compat_sqrt(8 * alpha);
+    const auto term2 = denom / (alpha - x);
+    const auto term3 = compat_pow(
+        x - alpha - alpha * compat_log(x / alpha),
+        static_cast(-1.5));
+    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
+    const auto term1 = compat_log(x / alpha) * term23 -
+        compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
+    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
+    const auto numer = x * term1;
+    return static_cast(-stirling * numer / denom);
+  }
+
+  // Use a bivariate rational approximation to the reparameterized gradient.
+  const auto u = compat_log(x / alpha);
+  const auto v = compat_log(alpha);
+  static const accscalar_t coef_uv[3][8] = {
+    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
+     1, 0.32668115, 0.10406089, 0.0014179084},
+    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
+     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
+    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
+     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
+  };
+  accscalar_t coef_v[8];
+  for (int i = 0; i < 8; ++ i) {
+    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
+  }
+  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
+  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
+  return static_cast(compat_exp(p / q));
+}
+
+// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
+// Assumes x is close to zero and uses a Taylor expansion.
+template 
+C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
+  const scalar_t factor = digamma_one(alpha)
+                        - digamma_one(alpha + beta) - compat_log(x);
+  scalar_t numer = 1;
+  scalar_t series = numer / alpha * (factor + 1 / alpha);
+  for (int i = 1; i <= 10; ++i) {
+    scalar_t casted_i = static_cast(i);
+    numer *= (casted_i - beta) * x / casted_i;
+    const scalar_t denom = alpha + casted_i;
+    series += numer / denom * (factor + 1 / denom);
+  }
+  const scalar_t result = x * compat_pow(1 - x, -beta) * series;
+  return isnan(result) ? static_cast( 0.f ) : result;
+}
+
+// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
+// Assumes x is close to zero and uses a Taylor expansion.
+template 
+C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
+  const scalar_t factor = digamma_one(alpha + beta) - digamma_one(beta);
+  scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
+  for (int i = 1; i <= 8; ++i) {
+    scalar_t casted_i = static_cast(i);
+    numer *= -x / casted_i;
+    dbetas = dbetas * (beta - casted_i) + betas;
+    betas = betas * (beta - casted_i);
+    series += numer / (alpha + casted_i) * (dbetas + factor * betas);
+  }
+  const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
+  return isnan(result) ? static_cast( 0.f ) : result;
+}
+
+// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
+// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
+// To ensure numerical stability, this computation is performed at higher precision.
+template
+C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
+  const accscalar_t total = alpha + beta;
+  const accscalar_t mean = alpha / total;
+  const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
+  if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
+    // Avoid the singularity at x = mean.
+    const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
+                           (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
+                           3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
+                           (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
+                           8 * (1 - x) * (135 * beta - 11)))));
+    const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
+    const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
+    return prefactor_num / (1 - x) * poly / prefactor_den;
+  }
+  const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
+  const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
+                             * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
+                             / (1 + 1 / (12 * total) + 1 / (288 * total * total));
+  const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
+  const accscalar_t axbx = alpha * (x - 1) + beta * x;
+  const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast(1.5f)) * axbx * axbx;
+  const accscalar_t term1 = term1_num / term1_den;
+  const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
+  const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
+  const accscalar_t term3_den = beta * x + alpha * (x - 1);
+  const accscalar_t term3 = term3_num / term3_den;
+  const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
+                               alpha * compat_log(alpha / (total * x));
+  const accscalar_t term4 = compat_pow(term4_base, static_cast(-1.5f));
+  const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
+  return static_cast(stirling * prefactor * term1234);
+}
+
+// Computes a scaled reparameterized gradient
+//   -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
+// for random number x drawn from a Beta distribution Beta(alpha,beta).
+// This function inputs total=alpha+beta to make it easy to implement
+// Dirichlet reparameterized gradients in terms of Betas.
+template
+C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
+  accscalar_t x_ = static_cast(x);
+  accscalar_t alpha_ = static_cast(alpha);
+  accscalar_t total_ = static_cast(total);
+
+  const scalar_t beta = total - alpha;
+  const accscalar_t beta_ = total_ - alpha_;
+  const scalar_t boundary = total * x * (1 - x);
+
+  // Use an asymptotic approximation for x close to 0.
+  if (x <= 0.5f && boundary < 2.5f) {
+    return _beta_grad_alpha_small(x, alpha, beta);
+  }
+
+  // Use an asymptotic approximation for x close to 1.
+  if (x >= 0.5f && boundary < 0.75f) {
+    return -_beta_grad_beta_small(1 - x, beta, alpha);
+  }
+
+  // Use an asymptotic approximation when alpha and (total - alpha) are both large.
+  if (alpha > 6 && beta > 6) {
+    return _beta_grad_alpha_mid(x_, alpha_, beta_);
+  }
+
+  // Use a rational correction to an analytic approximation.
+  static const accscalar_t c[2][3][3][4] = {
+    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
+      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
+      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
+     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
+      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
+      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
+     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
+      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
+      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
+    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
+      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
+      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
+     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
+      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
+      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
+     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
+      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
+      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
+  };
+  const accscalar_t u = compat_log(x_);
+  const accscalar_t a = compat_log(alpha_) - u;
+  const accscalar_t b = compat_log(total_) - a;
+  const accscalar_t pow_u[3] = {1, u, u * u};
+  const accscalar_t pow_a[3] = {1, a, a * a};
+  accscalar_t p = 0.0;
+  accscalar_t q = 0.0;
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      const accscalar_t ua = pow_u[i] * pow_a[j];
+      p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
+      q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
+    }
+  }
+  const accscalar_t approx = x_ * (digamma_one(total_) - digamma_one(alpha_)) / beta_;
+  return static_cast(p / q * approx);
+}
+
+} // namespace
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/EmbeddingBag.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/EmbeddingBag.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d7569f13a7b62ed86f2d055158712258edc98fe
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/EmbeddingBag.h
@@ -0,0 +1,159 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+#include 
+#include 
+
+#ifdef USE_FBGEMM
+#include 
+#endif
+
+namespace at::native {
+
+enum class EmbeddingBagMode {
+  SUM = 0,
+  MEAN = 1,
+  MAX = 2,
+};
+
+[[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
+  return op1 == static_cast(op2);
+}
+
+[[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
+  return !(op1 == op2);
+}
+
+void check_arguments(
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const int64_t mode,
+    const std::optional& per_sample_weights,
+    bool include_last_offset);
+
+void make_bag_size_out(
+    Tensor& bag_size_out,
+    const Tensor& offsets,
+    const Tensor& indices,
+    const int64_t mode,
+    const bool include_last_offset,
+    const bool requires_grad);
+
+void make_max_indices_out(
+    Tensor& max_indices_out,
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const Tensor& bag_size,
+    const int64_t mode,
+    bool include_last_offset);
+
+void make_offset2bag_out(
+    Tensor& offset2bag,
+    Tensor& output,
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const int64_t mode,
+    const std::optional& per_sample_weights,
+    const int64_t padding_idx = -1);
+
+#ifdef USE_FBGEMM
+
+template
+struct _CallbackAndBlockSize {
+    using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature::Type;
+
+    int64_t blockSize = -1;
+    TCallback callback = nullptr;
+
+    static TCallback generateCallback(int64_t block_size) {
+        return fbgemm::GenerateEmbeddingSpMDM(
+                block_size,
+                has_weight,
+                /* normalize_by_lengths */false,
+                /* prefetch */16,
+                /* is_weight_positional */false,
+                /* use_offsets */true);
+    }
+
+    _CallbackAndBlockSize() = default;
+
+    explicit _CallbackAndBlockSize(std::optional maybe_block_size)
+      : blockSize(maybe_block_size.value_or(-1))
+      , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
+    {}
+};
+
+template
+struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
+
+    _EmbeddingBagKernelCacheImpl() = default;
+    // use each of the mixins to store corresponding kernel and block size
+    explicit _EmbeddingBagKernelCacheImpl(std::optional maybe_block_size)
+      : StorageMixins(maybe_block_size)...
+    {}
+
+    // this method is thread safe (call sites may call from different threads)
+    template
+    typename _CallbackAndBlockSize::TCallback
+    getCallback(int64_t block_size) const {
+        // if the cache doesn't store the kernel for the incoming block size
+        // (so it is different from the one stored in corresponding mixin)
+        // regenerate the kernel (not writing it into the cache so we avoid locks)
+        if (block_size != _CallbackAndBlockSize::blockSize) {
+            return _CallbackAndBlockSize::generateCallback(block_size);
+        }
+        // else retrieve the cached kernel from the corresponding mixin
+        return _CallbackAndBlockSize::callback;
+    }
+};
+
+// instantiate the cache with the list of storage mixins
+// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
+using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize>;
+#else
+struct _EmbeddingBagKernelCache {
+    explicit _EmbeddingBagKernelCache(std::optional /* maybe_block_size */) {}
+};
+#endif
+
+void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
+    Tensor& bag_size, Tensor* max_indices,
+    const Tensor &weight, const Tensor &indices,
+    const Tensor &offsets, const int64_t mode = 0,
+    const std::optional& per_sample_weights = std::nullopt,
+    bool include_last_offset = false,
+    int64_t padding_idx = -1,
+    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
+
+void _embedding_bag_cpu_out(
+    at::Tensor& output,
+    at::Tensor& offset2bag,
+    at::Tensor& bag_size,
+    at::Tensor* p_max_indices,
+    const at::Tensor& weight,
+    const at::Tensor& indices,
+    const at::Tensor& offsets,
+    const bool scale_grad_by_freq,
+    const int64_t mode,
+    const bool sparse,
+    const std::optional& per_sample_weights,
+    const bool include_last_offset,
+    const std::optional& padding_idx,
+    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6d26c2e1d0286851018077475a9fead83df528f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h
@@ -0,0 +1,25 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+struct TensorIterator;
+
+namespace native {
+
+using _compute_linear_combination_fn = void(*)(  // function-pointer type used as the stub signature below
+  TensorIterator& iter,
+  int64_t in_stride,
+  int64_t coeff_stride,
+  int64_t num_summations
+);
+
+DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub)  // declares the stub with the signature above
+
+}} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GridSampler.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GridSampler.h
new file mode 100644
index 0000000000000000000000000000000000000000..bc561018ec5944b7cc9d118aeb220eee8e896591
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GridSampler.h
@@ -0,0 +1,303 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at::native {
+
+using detail::GridSamplerInterpolation;
+using detail::GridSamplerPadding;
+
+// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
+// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels
+//     -1 --> 0
+//     +1 --> (size - 1)
+//     scale_factor = (size - 1) / 2
+// if not align_corners: -1 and +1 get sent to the image edges
+//     -1 --> -0.5
+//     +1 --> (size - 1) + 0.5 == size - 0.5
+//     scale_factor = size / 2
+template <typename scalar_t>
+static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
+                                                bool align_corners) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    return ((coord + 1) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    return ((coord + 1) * size - 1) / 2;
+  }
+}
+
+// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
+// except that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
+                                                         bool align_corners, scalar_t *grad_in) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    *grad_in = static_cast<scalar_t>(size - 1) / 2;  // slope of the mapping returned below
+    return ((coord + 1) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    *grad_in = static_cast<scalar_t>(size) / 2;  // slope of the mapping returned below
+    return ((coord + 1) * size - 1) / 2;
+  }
+}
+
+// Clips coordinates to between 0 and clip_limit - 1
+template<typename scalar_t>
+static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
+  return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
+}
+
+// clip_coordinates_set_grad works similarly to clip_coordinates except that
+// it also returns the `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template<typename scalar_t>
+static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
+                                                 scalar_t *grad_in) {
+  // Note that it is important for the gradient calculation that borders
+  // are considered out of bounds.
+  if (in <= static_cast<scalar_t>(0)) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  } else {
+    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
+    if (in >= max) {
+      *grad_in = static_cast<scalar_t>(0);
+      return max;
+    } else {
+      *grad_in = static_cast<scalar_t>(1);  // strictly inside (0, max): identity mapping
+      return in;
+    }
+  }
+}
+
+// Reflects coordinates until they fall between low and high (inclusive).
+// The bounds are passed as twice their value so that half-integer values
+// can be represented as ints.
+template<typename scalar_t>
+static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
+                                           int64_t twice_high) {
+  if (twice_low == twice_high) {
+    return static_cast<scalar_t>(0);  // zero-width range: also avoids dividing by a zero span below
+  }
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = std::fabs(in - min);
+  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+  scalar_t extra = std::fmod(in, span);
+  int flips = static_cast<int>(std::floor(in / span));
+  if (flips % 2 == 0) {
+    return extra + min;
+  } else {
+    return span - extra + min;
+  }
+}
+
+// reflect_coordinates_set_grad works similarly to reflect_coordinates except
+// that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template<typename scalar_t>
+static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
+                                                    int64_t twice_high, scalar_t *grad_in) {
+  if (twice_low == twice_high) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  }
+  int grad_in_mult_;  // sign picked up when `in - min` is negated below
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = in - min;
+  if (in < static_cast<scalar_t>(0)) {
+    grad_in_mult_ = -1;
+    in = -in;
+  } else {
+    grad_in_mult_ = 1;
+  }
+  // `fmod` returns same sign as `in`, which is positive after the `if` above.
+  scalar_t extra = std::fmod(in, span);
+  int flips = static_cast<int>(std::floor(in / span));
+  if (flips % 2 == 0) {
+    *grad_in = static_cast<scalar_t>(grad_in_mult_);
+    return extra + min;
+  } else {
+    *grad_in = static_cast<scalar_t>(-grad_in_mult_);  // an odd number of flips reverses direction
+    return span - extra + min;
+  }
+}
+
+// Mapping the out-of-boundary points back into boundary
+// This would only affect padding_mode=border or reflection
+template<typename scalar_t>
+static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
+                                           GridSamplerPadding padding_mode,
+                                           bool align_corners) {
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates(coord, 0, 2*(size - 1));
+    } else {
+      coord = reflect_coordinates(coord, -1, 2*size - 1);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  }
+  return coord;  // any other padding mode leaves coord unchanged
+}
+
+// Computes the pixel source index value for a grid coordinate
+template <typename scalar_t>
+static inline scalar_t grid_sampler_compute_source_index(
+    scalar_t coord,
+    int64_t size,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+  coord = grid_sampler_unnormalize(coord, size, align_corners);  // [-1, 1] scale -> pixel scale
+  coord = compute_coordinates(coord, size, padding_mode, align_corners);  // apply border / reflection padding
+  return coord;
+}
+
+// grid_sampler_compute_source_index_set_grad works similarly to
+// grid_sampler_compute_source_index except that it also returns the
+// `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static inline scalar_t grid_sampler_compute_source_index_set_grad(
+    scalar_t coord,
+    int64_t size,
+    GridSamplerPadding padding_mode,
+    bool align_corners,
+    scalar_t *grad_in) {
+  scalar_t grad_clip, grad_refl;  // per-stage factors, multiplied into *grad_in below
+  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_clip;
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
+    } else {
+      coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_refl * grad_clip;
+  }
+  return coord;
+}
+
+static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
+  return !(h < 0 || h >= H || w < 0 || w >= W);  // h in [0, H) and w in [0, W)
+}
+
+static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
+  return !(d < 0 || d >= D || h < 0 || h >= H || w < 0 || w >= W);  // each index in [0, extent)
+}
+
+template<typename scalar_t>
+static inline scalar_t get_value_bounded(
+    const scalar_t* data,
+    scalar_t x,
+    scalar_t y,
+    int64_t W,
+    int64_t H,
+    int64_t sW,
+    int64_t sH,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+
+  x = compute_coordinates(x, W, padding_mode, align_corners);
+  y = compute_coordinates(y, H, padding_mode, align_corners);
+
+  int64_t ix = static_cast<int64_t>(x);
+  int64_t iy = static_cast<int64_t>(y);
+
+  if (within_bounds_2d(iy, ix, H, W)) {
+    return data[iy * sH + ix * sW];
+  }
+  return static_cast<scalar_t>(0);  // out-of-bounds reads yield zero
+}
+
+template<typename scalar_t>
+static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
+                               int64_t sH, int64_t sW, int64_t H, int64_t W,
+                               scalar_t delta) {
+  if (within_bounds_2d(h, w, H, W)) {  // out-of-bounds (h, w) is silently ignored
+    data[h * sH + w * sW] += delta;
+  }
+}
+
+template<typename scalar_t>
+static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
+                               int64_t sD, int64_t sH, int64_t sW,
+                               int64_t D, int64_t H, int64_t W,
+                               scalar_t delta) {
+  if (within_bounds_3d(d, h, w, D, H, W)) {  // out-of-bounds (d, h, w) is silently ignored
+    data[d * sD + h * sH + w * sW] += delta;
+  }
+}
+
+template<typename scalar_t>
+static inline void add_value_bounded(
+    scalar_t* data,
+    scalar_t x,
+    scalar_t y,
+    int64_t W,
+    int64_t H,
+    int64_t sW,
+    int64_t sH,
+    scalar_t delta,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+
+  x = compute_coordinates(x, W, padding_mode, align_corners);
+  y = compute_coordinates(y, H, padding_mode, align_corners);
+
+  int64_t ix = static_cast<int64_t>(x);
+  int64_t iy = static_cast<int64_t>(y);
+
+  safe_add_2d(data, iy, ix, sH, sW, H, W, delta);  // no-op when (iy, ix) is out of bounds
+}
+
+// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
+template<typename scalar_t>
+static inline void get_cubic_coefficients_grad(
+    scalar_t coeffs[4],
+    scalar_t t) {
+
+  // Must be the same as forward calculation in
+  // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
+  scalar_t A = -0.75;
+
+  scalar_t x;
+  x = -1 - t; // 1 < x = |-1 - tx| < 2
+  coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
+  x = -t;     // x = |0 - tx| <= 1
+  coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
+  x = 1 - t;  // x = |1 - tx| <= 1
+  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
+  x = 2 - t;  // 1 < x = |2 - tx| < 2
+  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
+}
+
+}  // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/IndexingUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/IndexingUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..f462528e7ffbfa9f87b5cacf8f6158948e8e1df6
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/IndexingUtils.h
@@ -0,0 +1,186 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+[[noreturn]]
+static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
+  TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,  // condition is the literal false: the check fails on every call
+  " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
+}
+
+[[maybe_unused]] static std::vector<Tensor> expandTensors(
+    const Tensor& self,
+    IOptTensorListRef indices,
+    bool ensure_same_device = false) {
+  // If indices come in as ByteTensor or BoolTensor (masks), expand them into
+  // the equivalent indexing by LongTensors
+  std::vector<Tensor> result;
+  for (const auto& index_opt : indices) {
+    if (!index_opt.has_value()) {
+      result.emplace_back();  // keep an undefined placeholder for the absent index
+    } else {
+      const auto& index = *index_opt;
+      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+        if (index.scalar_type() == kByte) {
+          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
+          " please use a dtype torch.bool instead.");
+        }
+        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
+        // corresponding dimensions in self
+        for (const auto j : c10::irange(index.dim())) {
+          int64_t srcIdx = static_cast<int64_t>(result.size() + j);
+          if (index.size(j) != self.size(srcIdx)) {
+            invalid_mask(self, srcIdx, index, j);
+          }
+        }
+        // Replace with nonzeros
+        at::Tensor nonzero;
+        if (ensure_same_device && index.device() != self.device()) {
+          bool non_blocking = index.is_cpu() && self.device().is_cuda();
+          auto out = at::empty({0}, index.options().dtype(kLong).pinned_memory(non_blocking));
+          nonzero = at::nonzero_out(out, index).to(self.device(), non_blocking);
+        } else {
+          nonzero = index.nonzero();
+        }
+        for (const auto j : c10::irange(index.dim())) {
+          result.emplace_back(nonzero.select(1, j));  // one index tensor per mask dimension
+        }
+      } else if (ensure_same_device && index.device() != self.device()) {
+        result.emplace_back(index.to(self.device()));
+      } else {
+        result.emplace_back(index);
+      }
+    }
+  }
+  return result;
+}
+
+[[maybe_unused]] static void checkIndexTensorTypes(
+    IOptTensorListRef indices,
+    bool allow_int = false) {
+  for (const auto& tensor : indices) {
+    if (!tensor.has_value() || !tensor->defined()) {
+      continue;  // absent / undefined index entries are not type-checked
+    }
+    const auto scalarType = tensor->scalar_type();
+    if (scalarType == kLong || scalarType == kByte || scalarType == kBool) {
+      continue;  // dtypes accepted regardless of allow_int
+    }
+    if (!allow_int) {
+      TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
+    } else if (scalarType != kInt) {
+      TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
+    }
+  }
+}
+
+inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
+  torch::List<std::optional<Tensor>> result;
+  result.reserve(list.size());
+  for (const Tensor& a : list) {
+    result.push_back(a);  // each tensor is wrapped as an optional element
+  }
+  return result;
+}
+
+inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
+  torch::List<std::optional<Tensor>> result;
+  result.reserve(list.size());
+  for (const IValue& a : list) {
+    result.push_back(a.isTensor() ? std::optional<Tensor>(a.toTensor()) : std::optional<Tensor>());  // non-tensor IValues become empty optionals
+  }
+  return result;
+}
+
+[[maybe_unused]] static bool hasContiguousSubspace(TensorList tl) {
+  // true if all the non-null tensors are adjacent
+  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
+  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
+  auto start = std::find_if(tl.begin(), tl.end(), isDefined);  // first defined tensor
+  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);  // last defined tensor, as a reverse iterator
+  auto it = std::find_if(start, stop.base(), isNull);  // look for an undefined tensor between first and last defined
+  return it == stop.base();  // none found: the defined tensors form a single run
+}
+
+// Transposes the tensor and indices together so that all the non-null indices
+// index the first k dimensions of the tensor. Returns the transposed tensor
+// and the reordered indices. For example:
+// transposeToFront(tensor, {nullptr, a, nullptr, b})
+// returns
+// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
+[[maybe_unused]] static std::tuple<Tensor, std::vector<Tensor>> transposeToFront(
+    const Tensor& self,
+    TensorList indices) {
+  std::vector<int64_t> dims;
+  std::vector<Tensor> transposedIndices;
+  dims.reserve(self.dim());
+  for (const auto i : c10::irange(self.dim())) {  // pass 1: dimensions that have a defined index
+    if (indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back(indices[i]);
+    }
+  }
+  for (const auto i : c10::irange(self.dim())) {  // pass 2: remaining dimensions, in original order
+    if (!indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back();
+    }
+  }
+  return std::make_tuple(self.permute(dims), std::move(transposedIndices));
+}
+
+inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
+transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
+  std::vector<int64_t> dims;
+  std::vector<int64_t> invPerm;
+  std::vector<Tensor> transposedIndices;
+  dims.reserve(self.dim());
+  invPerm.resize(self.dim());
+  for (const auto i : c10::irange(self.dim())) {  // pass 1: dimensions that have a defined index
+    if (indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back(indices[i]);
+    }
+  }
+  for (const auto i : c10::irange(self.dim())) {  // pass 2: remaining dimensions, in original order
+    if (!indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back();
+    }
+  }
+  for (const auto i : c10::irange(self.dim())) {
+    invPerm[dims[i]] = i;  // invPerm maps an original dimension to its position in `dims`
+  }
+  return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
+}
+
+struct AdvancedIndex {
+  AdvancedIndex(const Tensor& src, TensorList indices);  // declared here; defined elsewhere
+
+  Tensor src;
+  std::vector<Tensor> indices;
+  DimVector indexed_sizes;
+  DimVector indexed_strides;
+  int64_t dims_before;
+  int64_t dims_after;
+};
+
+
+} //namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Lerp.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Lerp.h
new file mode 100644
index 0000000000000000000000000000000000000000..9a51a5c712a9dacee54261aafa9e78d07445de67
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Lerp.h
@@ -0,0 +1,51 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
+  return std::abs(weight) < scalar_t(0.5);  // |weight| < 0.5
+}
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
+  // Avoid the sqrt in abs(weight): compare the squared magnitude against 0.5^2
+  return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
+}
+
+template <typename scalar_t, typename weight_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  using opmath_weight_t = at::opmath_type<weight_t>;
+
+  opmath_t self = self_;
+  opmath_t end = end_;
+  opmath_weight_t weight = weight_;
+
+  // Conditional for better numerics. This has been discussed in
+  // https://github.com/pytorch/pytorch/pull/18871
+  return is_lerp_weight_small(weight)
+      ? self + weight * (end - self)
+      : end - (end - self) * (opmath_t(1) - weight);
+}
+
+using lerp_fn_scalar = void (*)(  // kernel signature taking the weight as a Scalar
+    at::TensorIteratorBase& iter,
+    const Scalar& weight);
+
+using lerp_fn_tensor = void (*)(  // kernel signature with no separate weight argument
+    at::TensorIteratorBase& iter);
+
+DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight)
+DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..47633d336ff50c372081000540c37dca16e17650
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h
@@ -0,0 +1,629 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#include 
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
+  if (tensor.is_conj()) {
+    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());  // own the resolved tensor
+  } else {
+    return c10::MaybeOwned<Tensor>::borrowed(tensor);  // no conj bit set: borrow the input as-is
+  }
+}
+
+inline DimVector batched_matrix_contiguous_strides(
+    const IntArrayRef sizes,
+    const bool f_contig = false) {
+  // f_contig chooses between the strides of a batch of Fortran (F-contiguous)
+  // and C-contiguous matrices
+  auto strides = c10::contiguous_strides(sizes);
+  auto dim = strides.size();
+
+  if (f_contig && dim >= 2) {
+    // Fix the strides of the last two dimensions, so that we return
+    // C-contiguous batches of F-contiguous matrices.
+    strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));  // column stride = row count, clamped to at least 1
+    strides[dim - 2] = 1;
+  }
+  return strides;
+}
+
+/*
+ * Clones a Tensor so that the following conditions hold:
+ * If we think of a Tensor of having size (B, M, N), where B is any number
+ * of batch dimensions, then:
+ * - Each (M, N) matrix is in column major form
+ * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
+ *   Then when laid out in memory, the M by N matrix starting at
+ *   P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
+ *   matrix starting at Q.data_ptr()[B * M' * N'].
+ */
+inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
+  // Clone the matrix-transposed view into a contiguous buffer, then transpose
+  // the clone back in place. If src is already in batched column major
+  // format, the transposed view is contiguous, so the clone does not need to
+  // reorder any data and is fast.
+  Tensor column_major = src.mT().clone(at::MemoryFormat::Contiguous);
+  column_major.transpose_(-2, -1);
+  return column_major;
+}
+
+/*
+ * contig chooses between C-contig (true) and F-contig (false)
+ */
+inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
+  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
+              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
+                                                      : cloneBatchedColumnMajor(clone));
+}
+
+/*
+ * This method is designed to be a faster alternative to
+ * `cloneBatchedColumnMajor` with some additional features,
+ * namely:
+ * 1. It uses `copy` instead of `clone` which could be much faster.
+ * 2. `nrows` parameter used to create inputs with the number of rows larger
+ *  than the original input, which is required for some LAPACK/MAGMA methods.
+ * 3. `desired_batch_sizes` is used to create copies with the batch size
+ *  which is either the original batch size of the input, or its larger
+ *  broadcasted shape.
+ */
+inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
+    at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
+  if (nrows == -1) { nrows = src.size(-2); }  // -1 means: keep the source row count
+  auto out_sizes = desired_batch_sizes.has_value()
+    ? desired_batch_sizes.value().vec()
+    : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
+  out_sizes.insert(out_sizes.end(), {nrows, src.size(-1)});
+  auto result = at::empty_strided(
+      out_sizes, batched_matrix_contiguous_strides(out_sizes, /*f-contig*/true), src.options());
+  result.narrow(-2, 0, src.size(-2)).copy_(src);  // only the first src.size(-2) rows are written
+  return result;
+}
+
+/*
+ * Given batches of matrices with arbitrary batch dim,
+ * computes the number of batches.
+ */
+inline int64_t batchCount(const Tensor& batched_matrices) {
+  // every dimension except the trailing two is treated as a batch dimension
+  const int64_t num_batch_dims = batched_matrices.ndimension() - 2;
+  int64_t count = 1;
+  for (int64_t d = 0; d < num_batch_dims; ++d) { count *= batched_matrices.size(d); }
+  return count;
+}
+
+// Computes the number of elements of a matrix in a batched matrix tensor
+inline int64_t matrixStride(const Tensor& batched_matrices) {
+  return batched_matrices.size(-2) * batched_matrices.size(-1);  // product of the last two sizes
+}
+
+// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
+inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
+  TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");  // f_name / arg_name only feed the error message
+}
+inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
+  checkIsMatrix(self, f_name, arg_name);  // rank check first, so the -1 / -2 lookups below are valid
+  TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
+              f_name,
+              ": ", arg_name, " must be batches of square matrices, "
+              "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
+}
+
+inline void checkInputsSolver(const Tensor& A,
+                                     const Tensor& B,
+                                     const bool left,
+                                     const char* const f_name) {
+  squareCheckInputs(A, f_name, "A");  // A: at least 2-D with square trailing dims
+  checkIsMatrix(B, f_name, "B");  // B: at least 2-D
+  TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),  // left: row counts must match; otherwise: column counts
+              f_name, ": Incompatible shapes of A and B for the equation ",
+              left ? "AX = B" : "XA = B",
+              " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
+}
+
+inline bool is_row_or_column_contiguous(const Tensor& t) {
+  // Deliberately conservative: accept only a contiguous tensor, or one whose last-two-dims transpose
+  // is contiguous. Strides such as (6, 12, 1, 3) or (3, 1, 9) are not recognized here.
+  if (t.is_contiguous()) { return true; }
+  return t.transpose(-2, -1).is_contiguous();
+}
+
+inline TransposeType to_transpose_type(const bool contig, const bool conj) {
+  if (!conj) {
+    return contig ? TransposeType::NoTranspose : TransposeType::Transpose;
+  }
+  // conj is true from here on; the (contig, conj) combination has no
+  // TransposeType counterpart and trips the internal assert.
+  if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
+  return TransposeType::ConjTranspose;
+}
+
+
+// This function is designed to be used with linear algebra methods that minimize
+// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
+// or the L2 norm (`lstsq`).
+// It is expected that `a` and `b` are contiguous tensors of column-major matrices
+// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
+// with the following additional properties:
+//
+// 1. a.dim() == b.dim()
+// 2. a.shape[:-2] broadcasts over b.shape[:-2]
+// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
+//
+// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
+// is to be memory efficient, which means that if there exists an index i such that
+// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
+// then instead of materializing copies of `a` in the broadcasted shape, we keep
+// a buffer copy of `a` along with flags that check whether specific batch dimension
+// indices for `a` were already accessed. If they were, we copy the data from the buffer
+// into `a`. The number of copies does not exceed
+// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
+// and this value is attained by tensors with non-empty batch dimensions.
+//
+// func_t `f` is a callable that is being supplied with
+// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
+// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
+// and a_linear_batch_idx is an index in the 3d representation which corresponds to
+// the memory a_working_ptr points to, in other words:
+// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}).select(0, a_linear_batch_idx).data_ptr<scalar_t>();
+// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
+// its rank or singular values (see linalg_lstsq).
+template<typename scalar_t, typename func_t>
+void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
+  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
+  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);
+
+  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
+  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);
+
+  TensorIterator iter = TensorIteratorConfig()
+    .set_check_mem_overlap(false)
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .add_output(b_linear_batch_idx)
+    .add_input(a_linear_batch_idx)
+    .build();
+
+  auto m = a.size(-2);
+  auto n = a.size(-1);
+  auto a_3d = a.view({batchCount(a), m, n});
+  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});
+
+  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
+  Tensor a_buffer, a_was_accessed, a_buffer_3d;
+  std::function<void(int64_t)> check_if_copy_needed_for_a
+    = [](int64_t /*a_curr_linear_batch_idx*/){};  // default: no-op when batch shapes are equal
+  if (a_broadcasts_over_b) {
+    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
+      .copy_(a);
+    a_was_accessed = at::zeros(batchCount(a), at::kBool);
+    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
+    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
+      auto* a_was_accessed_flag = a_was_accessed
+        .select(0, a_curr_linear_batch_idx)
+        .data_ptr<bool>();
+      if (!(*a_was_accessed_flag)) {
+        *a_was_accessed_flag = true;  // first visit: `a` still holds the original data
+      }
+      else {
+        a_3d.select(0, a_curr_linear_batch_idx)
+          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));  // revisit: restore this batch from the buffer
+      }
+    };
+  }
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
+    auto* b_batch_idx_ptr = data[0];
+    auto* a_batch_idx_ptr = data[1];
+
+    for ([[maybe_unused]] const auto elem : c10::irange(nelems)) {
+      auto b_curr_linear_batch_idx =
+          *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
+      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);
+
+      check_if_copy_needed_for_a(a_curr_linear_batch_idx);
+
+      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
+        .data_ptr<scalar_t>();
+      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
+        .data_ptr<scalar_t>();
+      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);
+
+      b_batch_idx_ptr += strides[0];
+      a_batch_idx_ptr += strides[1];
+    }
+  };
+  iter.serial_for_each(loop, {0, batchCount(b)});
+}
+
+// Returns the epsilon value for floating types except half
+inline double _get_epsilon(const ScalarType& sc_type) {
+  switch (sc_type) {
+    case at::ScalarType::Float:
+      return static_cast<double>(std::numeric_limits<float>::epsilon());
+    case at::ScalarType::Double:
+      return std::numeric_limits<double>::epsilon();
+    default:
+      TORCH_CHECK(false, "This function doesn't handle types other than float and double");
+  }
+}
+
+// Validates input shapes and devices
+// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
+inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {  // `self` is reported as "b" in the messages below
+  TORCH_CHECK(self.device() == A.device(),
+              "Expected b and A to be on the same device, but found b on ",
+              self.device(), " and A on ", A.device(), " instead.");
+
+  TORCH_CHECK(self.scalar_type() == A.scalar_type(),
+              "Expected b and A to have the same dtype, but found b of type ",
+              self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");
+
+  TORCH_CHECK(A.size(-1) == A.size(-2),
+              "A must be batches of square matrices, "
+              "but they are ", A.size(-2), " by ", A.size(-1), " matrices");
+
+  TORCH_CHECK(A.size(-1) == self.size(-2),
+              "Incompatible matrix sizes for ", name, ": each A "
+              "matrix is ", A.size(-1), " by ", A.size(-1),  // A was verified square by the previous check
+              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
+}
+
+// Raises (through TORCH_CHECK) unless `t` has a floating point or complex dtype.
+// When `allow_low_precision_dtypes` is false the dtype must additionally be one
+// of kFloat, kDouble, kComplexFloat or kComplexDouble.
+// `f_name` is used as the prefix of the error message.
+inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
+  const auto input_dtype = t.scalar_type();
+  const bool floating_or_complex =
+      at::isFloatingType(input_dtype) || at::isComplexType(input_dtype);
+  TORCH_CHECK(floating_or_complex,
+              f_name, ": Expected a floating point or complex tensor as input. Got ", input_dtype);
+  if (allow_low_precision_dtypes) {
+    return;
+  }
+  const bool full_precision =
+      input_dtype == kFloat || input_dtype == kDouble ||
+      input_dtype == kComplexFloat || input_dtype == kComplexDouble;
+  TORCH_CHECK(full_precision,
+              f_name, ": Low precision dtypes not supported. Got ", input_dtype);
+}
+
+
+// Raises (through TORCH_CHECK) unless every tensor in `tensors` has exactly
+// `dim` dimensions. The first offending tensor triggers the error.
+inline void checkAllSameDim(TensorList tensors, int64_t dim) {
+  for (const auto& tensor : tensors) {
+    const bool has_expected_dim = tensor.dim() == dim;
+    TORCH_CHECK(has_expected_dim, "Tensor dimension is ", tensor.dim(), ", expected ", dim, " instead.");
+  }
+}
+
+// Computes the sizes `arg1` and `arg2` have to be expanded to so that their
+// batch dimensions (everything except the last two dimensions) broadcast
+// against each other.
+// Returns (arg1_expand_size, arg2_expand_size): each is the broadcast batch
+// shape followed by that tensor's own last two sizes. No tensor is modified.
+inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
+  // broadcast the batch dimensions of arg1 and arg2.
+  IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
+  IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
+
+  // Start from the shared batch shape, then append each tensor's matrix sizes.
+  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
+  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
+
+  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
+  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
+  return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
+}
+
+// Broadcasts the batch dimensions of `arg1` and `arg2` against each other and
+// returns the pair of resulting tensors.
+// When `name` is non-null, linearSolveCheckInputs(arg1, arg2, name) runs first.
+// A tensor whose sizes already equal its target shape is returned as-is
+// instead of being expanded.
+inline std::tuple<Tensor, Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
+  // If there's no name we assume we don't want to check the errors
+  if (name != nullptr) {
+    linearSolveCheckInputs(arg1, arg2, name);
+  }
+
+  auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
+
+  auto arg1_broadcasted  = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
+  auto arg2_broadcasted  = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
+  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
+}
+
+// Returns the broadcast of the first `n_batch_dims` sizes of `t1` and `t2`,
+// as computed by infer_size.
+inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
+  IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
+  IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
+  auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
+  return broadcasted_batch_sizes;
+}
+
+// Returns `self` permuted so that the dimensions listed in `axes` come last
+// (in the order given), preceded by all remaining dimensions in their original
+// order.
+// `axes` entries are compared against the indices 0..ndim-1 as-is. If the
+// assembled permutation does not have exactly ndim entries, TORCH_CHECK raises
+// with a "duplicate or invalid axis" message.
+inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
+  const std::vector<int64_t> a = axes.vec();
+  const int64_t ndim = self.ndimension();
+  std::vector<int64_t> perm;
+
+  // First every dimension that is NOT being moved, in ascending order...
+  for (const auto i : c10::irange(ndim)) {
+    auto it = std::find(a.begin(), a.end(), i);
+    if (it == a.end()) {
+       perm.push_back(i);
+    }
+  }
+  // ...then the requested axes, in the order the caller listed them.
+  for (auto i : a) {
+    perm.push_back(i);
+  }
+
+  TORCH_CHECK((int64_t)perm.size() == ndim,
+    "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
+
+  return self.permute(perm);
+}
+
+// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
+//
+//   "reduced"  -> (true,  true)
+//   "complete" -> (true,  false)
+//   "r"        -> (false, true)
+// Any other string raises through TORCH_CHECK.
+inline std::tuple<bool, bool> _parse_qr_mode(std::string_view mode) {
+  bool compute_q;
+  bool reduced;
+  if (mode == "reduced") {
+    compute_q = true;
+    reduced = true;
+  } else if (mode == "complete") {
+    compute_q = true;
+    reduced = false;
+  } else if (mode == "r") {
+    compute_q = false;
+    reduced = true; // this is actually irrelevant in this mode
+  } else {
+      TORCH_CHECK(false, "qr received unrecognized mode '", mode,
+                  "' but expected one of 'reduced' (default), 'r', or 'complete'");
+  }
+  return std::make_tuple(compute_q, reduced);
+}
+
+// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
+//
+// Returns (q_sizes, q_strides, n_columns_q):
+//   * q_sizes starts as a copy of input.sizes(); only its last entry is
+//     overwritten (m when `!reduced && m > n`, otherwise n),
+//   * q_strides comes from batched_matrix_contiguous_strides(q_sizes, true),
+//   * n_columns_q is m when `!reduced && m > n`, otherwise min(m, n).
+// Here m = input.size(-2) and n = input.size(-1).
+inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
+    const Tensor& input,
+    bool reduced) {
+  int64_t m = input.size(-2), n = input.size(-1);
+  int64_t n_columns_q;
+
+  // We need to compute the required size of Q based on the `reduced` option
+  DimVector q_sizes(input.sizes());
+  if (!reduced && m > n) {
+    q_sizes[input.dim() - 1] = m;
+    n_columns_q = m;
+  } else {
+    q_sizes[input.dim() - 1] = n;
+    n_columns_q = std::min(m, n);
+  }
+  auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
+  return std::make_tuple(q_sizes, q_strides, n_columns_q);
+}
+
+// True when `A` is a CUDA tensor, the global context reports cuSOLVER as
+// available, and the preferred linalg backend is not Magma.
+inline bool svd_uses_cusolver(const Tensor& A) {
+  if (!A.is_cuda()) {
+    return false;
+  }
+  if (!at::globalContext().hasCuSOLVER()) {
+    return false;
+  }
+  // cusolver is available: use it unless the user explicitly prefers Magma
+  return at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
+}
+
+
+// Used in place of `.to` when the original strides must be retained: `.to`
+// does not keep strides and makes the output tensor contiguous.
+// Allocates a tensor with the sizes and strides of `original_tensor` and the
+// given `options`, copies the data into it, and returns it.
+inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
+  const auto sizes = original_tensor.sizes();
+  const auto strides = original_tensor.strides();
+  Tensor converted = at::empty_strided(sizes, strides, options);
+  converted.copy_(original_tensor);
+  return converted;
+}
+
+// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
+// the two specified dimensions to the end of a tensor, without changing the order of
+// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
+// placed just to the left of it.
+//
+// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
+// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
+// be `vec(0, 2, 1, 3)`.
+//
+// Raises through TORCH_CHECK when dim0 == dim1 or either lies outside [0, ndim).
+inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
+  TORCH_CHECK(
+    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
+    "duplicate or invalid dimensions");
+  std::vector<int64_t> permutation(ndim);
+  int64_t cur_permuted_dim = 0;
+  // Fill the front with every dimension other than dim0/dim1, order preserved.
+  for (const auto dim_ind : c10::irange(ndim)) {
+    if ((dim_ind != dim0) && (dim_ind != dim1)) {
+      permutation[cur_permuted_dim++] = dim_ind;
+    }
+  }
+  // The last two slots receive dim0 and then dim1.
+  permutation[cur_permuted_dim++] = dim0;
+  permutation[cur_permuted_dim] = dim1;
+  return permutation;
+}
+
+// Creates a dimension permutation array that can be given to `at::permute()`, which
+// will reverse a given permutation.
+// The reverse permutation array is created by swapping the indices and their
+// associated values from the given permutation array, i.e.
+// reverse[permutation[i]] = i for every i.
+inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
+  int64_t ndim = permutation.size();
+  std::vector<int64_t> reverse_permutation(ndim);
+  for (const auto dim_ind : c10::irange(ndim)) {
+    reverse_permutation[permutation[dim_ind]] = dim_ind;
+  }
+  return reverse_permutation;
+}
+
+// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
+// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
+//
+// `jobz` is the job character and `m`, `n` are the matrix dimensions; the
+// result depends only on min(m, n), max(m, n) and whether jobz == 'N'.
+inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
+  const int64_t min_dim = std::min(m, n);
+  const int64_t max_dim = std::max(m, n);
+  if (jobz == 'N') {
+#ifdef __APPLE__
+    // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
+    return 7 * min_dim;
+#else
+    // This setting is valid for LAPACK 3.6+
+    return 5 * min_dim;
+#endif
+  }
+  const int64_t base_size = 5 * min_dim * min_dim + 5 * min_dim;
+  if (max_dim > 10 * min_dim) {
+    return base_size;
+  }
+  const int64_t alt_size = 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim;
+  return std::max(base_size, alt_size);
+}
+
+// This function checks whether the uplo argument input is valid
+// Allowed strings are "u", "U", "l", "L"
+// Any other input (including the empty string) raises through TORCH_CHECK.
+inline void checkUplo(const std::string_view uplo) {
+  // Validate the length before touching uplo[0]: indexing an empty
+  // string_view is undefined behavior.
+  TORCH_CHECK(uplo.size() == 1,
+    "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
+  // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
+  char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
+  TORCH_CHECK(uplo_uppercase == 'U' || uplo_uppercase == 'L',
+    "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
+}
+
+// Raises (through TORCH_CHECK) unless `result` and `input` are on the same
+// device. `fn_name` prefixes the message and `result_name` is the label used
+// for `result` inside it.
+inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
+  const bool same_device = result.device() == input.device();
+  TORCH_CHECK(
+      same_device,
+      fn_name,
+      ": Expected ", result_name, " and input tensors to be on the same device, but got ",
+      result_name, " on ", result.device(), " and input on ", input.device());
+}
+
+// Dtype check between result and input tensors, for _out variants.
+// Most linear algebra functions produce the same dtype they consume (floating
+// or complex), so it is enough to verify that input's dtype can be cast to
+// result's dtype.
+// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
+// c10::canCast is what defines the "safe copy" dtype requirement.
+// Raises through TORCH_CHECK when the cast is not allowed.
+inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
+  const auto input_dtype = input.scalar_type();
+  const auto result_dtype = result.scalar_type();
+  TORCH_CHECK(
+      c10::canCast(input_dtype, result_dtype),
+      fn_name,
+      ": Expected ", result_name, " to be safely castable from ", input_dtype, " dtype, but got ",
+      result_name, " with dtype ", result_dtype);
+}
+
+// Dtype-only overload: verifies that the expected output dtype (`result_type`)
+// can be safely cast to the dtype of the out tensor (`out_type`).
+// Raises through TORCH_CHECK otherwise; `out_name` labels the out tensor in
+// the message.
+inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
+  TORCH_CHECK(
+      c10::canCast(result_type, out_type),
+      fn_name,
+      ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
+      out_name, " with dtype ", out_type);
+}
+
+// Raises (through TORCH_CHECK) when the tolerance tensor `tol` has a complex
+// dtype. `f_name` and `tol_name` are only used to build the message.
+inline void checkNotComplexTolerance(const Tensor& tol, const std::string_view f_name, const std::string_view tol_name) {
+  const auto tol_dtype = tol.scalar_type();
+  const bool tol_is_complex = at::isComplexType(tol_dtype);
+  TORCH_CHECK(!tol_is_complex,
+              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol_dtype);
+}
+
+/*
+  Two types of 'other' tensors are supported when solving
+  a system of linear equations matmul(input, x) = other:
+  * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
+  * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
+  The original torch.solve supported only the matrix case, while NumPy works for both cases.
+  For the batched input we need to be able to distinguish them.
+  Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
+  This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
+
+  Returns true for the vector case, false for the matrix case.
+*/
+inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
+  // input.shape[:-1]
+  const SymIntArrayRef expected_batched_rhs_shape(input.sym_sizes().data(), input.dim() - 1);
+  // A plain 1D right-hand side is always a vector.
+  if (other.dim() == 1) {
+    return true;
+  }
+  // Otherwise `other` must have one dimension fewer than `input`...
+  if (input.dim() - 1 != other.dim()) {
+    return false;
+  }
+  // ...and match input.shape[:-1] exactly.
+  return other.sym_sizes().equals(expected_batched_rhs_shape);
+}
+
+/*
+  Builds the linear indices that let a tensor of original_shape be read as if
+  it were a materialized broadcast tensor: arange(numel) on CPU with dtype
+  kLong, viewed as original_shape, broadcast to broadcast_shape and made
+  contiguous.
+*/
+inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
+  const TensorOptions cpu_long_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
+  Tensor flat_indices = at::arange(numel, cpu_long_options);
+  Tensor shaped_indices = flat_indices.view(original_shape);
+  return shaped_indices.broadcast_to(broadcast_shape).contiguous();
+}
+
+// Maps a linear index into a broadcast tensor back to the linear index of the
+// corresponding element in the original (un-broadcast) tensor.
+// When original_shape equals broadcast_shape no index table is built and the
+// call operator is the identity.
+class BroadcastLinearIndices {
+ private:
+  // Index table from get_linear_indices; only assigned when broadcasting.
+  Tensor linear_indices_;
+  bool is_broadcasting_;
+
+ public:
+  BroadcastLinearIndices(
+      int64_t numel,
+      IntArrayRef original_shape,
+      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
+    // The assumption is that the broadcast_shape is a materialized broadcast
+    // shape of the original_shape. We need to compute the linear indices
+    // compatible with the original_shape to access the elements in the original
+    // tensor corresponding to the broadcast tensor.
+    if (is_broadcasting_) {
+      linear_indices_ =
+          get_linear_indices(numel, original_shape, broadcast_shape);
+    }
+  }
+  // Returns the original-tensor linear index for `broadcast_linear_index`.
+  // No bounds checking is done on the lookup.
+  int64_t operator()(int64_t broadcast_linear_index) {
+    return is_broadcasting_
+        // the table is created with dtype kLong, hence the int64_t pointer
+        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
+        : broadcast_linear_index;
+  }
+};
+
+// Reports whether `input` (at least 2-D) has a column-major layout.
+//   * more than 3 dims: true iff input.transpose(-2, -1) is contiguous;
+//   * 2 or 3 dims: the stride of dim -2 must be 1, the stride of the last dim
+//     (the leading dimension) must be >= max(1, rows), and for 3-D input the
+//     batch stride must be >= leading_dimension * cols.
+inline bool is_blas_compatible_column_major_order(const Tensor& input) {
+  IntArrayRef input_strides = input.strides();
+  IntArrayRef input_sizes = input.sizes();
+  auto ndim = input.dim();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
+  if (ndim > 3) {
+    return input.transpose(-2, -1).is_contiguous();
+  }
+  auto leading_dimension = input_strides[ndim - 1];
+  auto rows = input_sizes[ndim - 2];
+  bool batch_stride_compatible = true;
+  if (ndim == 3) {
+    auto cols = input_sizes[ndim - 1];
+    batch_stride_compatible =
+        input_strides[ndim - 3] >= leading_dimension * cols;
+  }
+  // Explicit <int64_t>: the literal 1 is an int, so plain std::max cannot deduce.
+  return (input_strides[ndim - 2] == 1) &&
+      (leading_dimension >= std::max<int64_t>(1, rows)) &&
+      batch_stride_compatible;
+}
+
+// Reports whether `input` (at least 2-D) has a row-major layout.
+//   * more than 3 dims: true iff input is contiguous;
+//   * 2 or 3 dims: the stride of the last dim must be 1, the stride of dim -2
+//     (the leading dimension) must be >= max(1, cols), and for 3-D input the
+//     batch stride must be >= leading_dimension * rows.
+inline bool is_blas_compatible_row_major_order(const Tensor& input) {
+  IntArrayRef input_strides = input.strides();
+  IntArrayRef input_sizes = input.sizes();
+  auto ndim = input.dim();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
+  if (ndim > 3) {
+    return input.is_contiguous();
+  }
+  auto leading_dimension = input_strides[ndim - 2];
+  auto cols = input_sizes[ndim - 1];
+  bool batch_stride_compatible = true;
+  if (ndim == 3) {
+    auto rows = input_sizes[ndim - 2];
+    batch_stride_compatible =
+        input_strides[ndim - 3] >= leading_dimension * rows;
+  }
+  // Explicit <int64_t>: the literal 1 is an int, so plain std::max cannot deduce.
+  return (input_strides[ndim - 1] == 1) &&
+      (leading_dimension >= std::max<int64_t>(1, cols)) &&
+      batch_stride_compatible;
+}
+
+}  // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LossMulti.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LossMulti.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a8b87e937a671c5fe04aea722502eda558f773b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LossMulti.h
@@ -0,0 +1,74 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+  // Validates the input/target shapes for multilabel margin loss and writes
+  // the derived sizes into the out-parameters `nframe` and `dim`.
+  //   * ndims <= 1: nframe = 1, dim = 1 (ndims == 0) or input.size(0);
+  //     target must have at most 1 dimension and exactly `dim` elements.
+  //   * otherwise:  nframe = input.size(0), dim = input.size(1);
+  //     target must be 2-D with shape (nframe, dim).
+  // `ndims` is supplied by the caller; presumably it is input.dim() — TODO confirm.
+  // Any violated condition raises through TORCH_CHECK.
+  inline void multilabel_margin_loss_shape_check(
+    int64_t& nframe,
+    int64_t& dim,
+    const int64_t& ndims,
+    const Tensor& input,
+    const Tensor& target) {
+    // The non-batch dimension must be non-empty; a 0-dim input is accepted.
+    TORCH_CHECK(
+        (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
+        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
+        input.sizes());
+
+    if (ndims <= 1) {
+      nframe = 1;
+      dim = ndims == 0 ? 1 : input.size(0);
+      TORCH_CHECK(
+          target.dim() <= 1 && target.numel() == dim,
+          "inconsistent target size: ", target.sizes(), " for input of size: ",
+          input.sizes());
+    } else {
+      nframe = input.size(0);
+      dim = input.size(1);
+      TORCH_CHECK(
+          target.dim() == 2 && target.size(0) == nframe &&
+          target.size(1) == dim,
+          "inconsistent target size: ", target.sizes(), " for input of size: ",
+          input.sizes());
+    }
+  }
+
+  // Validates the input/target/weight shapes for multi margin loss and writes
+  // the derived sizes into the out-parameters `nframe` and `dim`.
+  //   * ndims <= 1: nframe = 1, dim = 1 (ndims == 0) or input.size(0);
+  //   * otherwise:  nframe = input.size(0), dim = input.size(1).
+  // target must have at most 1 dimension and exactly `nframe` elements.
+  // weight is only checked when it holds a defined tensor: it must then have
+  // at most 1 dimension and exactly `dim` elements.
+  // Any violated condition raises through TORCH_CHECK.
+  inline void multi_margin_loss_shape_check(
+    int64_t& nframe,
+    int64_t& dim,
+    const int64_t& ndims,
+    const Tensor& input,
+    const Tensor& target,
+    const std::optional<Tensor>& weight) {
+    // The non-batch dimension must be non-empty; a 0-dim input is accepted.
+    TORCH_CHECK(
+        (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
+        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
+        input.sizes());
+
+    if (ndims <= 1) {
+      nframe = 1;
+      dim = ndims == 0 ? 1 : input.size(0);
+    } else {
+      nframe = input.size(0);
+      dim = input.size(1);
+    }
+
+    TORCH_CHECK(
+        target.dim() <= 1 && target.numel() == nframe,
+        "inconsistent target size, expected ", nframe, " but got ",
+        target.sizes());
+    if (weight && weight->defined()) {
+      TORCH_CHECK(
+          weight->dim() <= 1 && weight->numel() == dim,
+          "inconsistent weight size, expected ", dim, " but got ",
+          weight->sizes());
+    }
+  }
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MathBitsFallback.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MathBitsFallback.h
new file mode 100644
index 0000000000000000000000000000000000000000..f022190da40fe9898d3701a37b8fdbc9c7f09613
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MathBitsFallback.h
@@ -0,0 +1,162 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+
+#include 
+#endif
+
+namespace at::native {
+// This fallback should only be used for operations that are self inverse and have a corresponding tensor
+// bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
+// Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
+// Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.
+
+// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
+// Base class for boxed "math bit" fallbacks. Subclasses implement is_bit_set()
+// to say whether a tensor carries the bit; fallback_impl() materializes
+// (clones) every tensor argument that has the bit set, redispatches to the
+// keys after `key`, and copies the result back into a mutated input when one
+// was cloned.
+struct MathOpFallback {
+  MathOpFallback(DispatchKey key_, std::string op_name_) : key(key_), op_name(std::move(op_name_)) {}
+  // Returns true when the given tensor has this fallback's math bit set.
+  virtual bool is_bit_set(const Tensor&) = 0;
+  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+    /*
+      Situations to handle:
+        1. Out-of-place operation.  Easy: materialize all inputs and
+          call it a day.
+        2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_().
+          Materialize other inputs as in (1).
+        3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
+        Materialize other inputs as in (1).
+
+        It is important to be able to tell if we READ from an argument and if we
+        WRITE to an argument.  Conservative approach is to assume that we always
+        READ from an argument, but in out= operations you can skip
+        conjugating inputs on entry that never get used. In the current schema we
+        can't easily tell if the operation is in in-place or out= operation.
+
+        Note:
+        1. Mutable tensorlists containing tensors whose math bit set to true are disallowed.
+        2. Mutable tensors with math bit set to true are unconditionally cloned to ensure
+           correct behavior in the case when the mutable tensor shares memory with non mutable arguments.
+
+           If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory
+           with these mutable inputs would read into wrong values in the following cases:
+           1. Non mutable inputs have their math bit set to false.
+           2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory
+              with one or more mutable arg(s)) are cloned.
+           At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs.
+    */
+    const auto& arguments = op.schema().arguments();
+    const auto num_arguments = arguments.size();
+    // The operator's arguments are the last `num_arguments` entries of the stack.
+    const auto stack_start = stack->size() - num_arguments;
+
+    // Unset until the first argument with alias info is seen; afterwards every
+    // aliasing argument has to agree with it.
+    std::optional<bool> is_write;
+    for (const auto i : c10::irange(num_arguments)) {
+      // Three possible states:
+      // 1. alias_info has no value --> out-of-place operation
+      // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
+      // 3. alias_info does have a value, alias_info->is_write=False --> view operation
+      const AliasInfo* alias_info = arguments[i].alias_info();
+      if (alias_info != nullptr) {
+        if (is_write.has_value()) {
+          TORCH_CHECK(*is_write == alias_info->isWrite(),
+            "Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
+            op_name, " fallback doesn't work for operators with a mix "
+            "mutable and non-mutable inputs that alias with outputs, "
+            "this must be implemented manually.  "
+            "If you got this error on a core op, please report a bug to PyTorch.");
+        } else {
+          is_write = alias_info->isWrite();
+        }
+      }
+    }
+
+    if (is_write.has_value() && !*is_write) {
+      // We assume that view operators automatically handle the math bit
+      // correctly by propagating the dispatch key in key_set.
+      // This is not necessarily always right, so you should test these cases.
+      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
+      return;
+    }
+
+    // Mutable inputs with math bit set to True and their clones
+    std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
+    for (const auto i : c10::irange(num_arguments)) {
+      auto& ivalue = (*stack)[stack_start + i];
+      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
+        continue;
+      }
+      const auto& argument = arguments[i];
+      bool mut_arg = false;
+      if (argument.alias_info()) {
+        // Was already tested by is_write loop above
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
+        mut_arg = true;
+      }
+      if (ivalue.isTensor()) {
+        if (!is_bit_set(ivalue.toTensor())) {
+          continue;
+        }
+        auto tensor = std::move(ivalue).toTensor();
+        auto resolved_tensor = at::clone(tensor);
+        if (mut_arg) {
+          TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ",
+            op_name, "bit set to true.");
+          mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
+        }
+        // The operator sees the clone instead of the original tensor.
+        (*stack)[stack_start + i] = std::move(resolved_tensor);
+      } else if (ivalue.isTensorList()) {
+        auto tensors = std::move(ivalue).toTensorList();
+        for(const auto j : c10::irange(tensors.size())) {
+          const auto& tensor = tensors[j];
+          if (!is_bit_set(tensor)) {
+            continue;
+          }
+          TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
+              op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
+              op.schema().name());
+          tensors[j] = at::clone(tensor);
+        }
+        (*stack)[stack_start + i] = std::move(tensors);
+      }
+    }
+
+    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
+
+    TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);
+
+    // Write the operator's result back into the original mutable input (at
+    // most one, per the assert above) and put that original on the stack.
+    for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
+      auto& mutable_input =  mut_tensors.first;
+      auto& cloned_mutable_input =  mut_tensors.second;
+      auto& ivalue = (*stack)[stack_start];
+      auto returned_output = std::move(ivalue).toTensor();
+
+      // sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
+      TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));
+
+      // necessary for out= arg
+      at::native::resize_output(mutable_input, returned_output.sizes());
+
+      mutable_input.copy_(returned_output);
+      (*stack)[stack_start] = std::move(mutable_input);
+    }
+  }
+
+  virtual ~MathOpFallback() = default;
+
+  // Dispatch key of this math bit; redispatch uses only the keys after it.
+  DispatchKey key;
+  // Human-readable name of the bit, used in error messages.
+  std::string op_name;
+};
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MaxPooling.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MaxPooling.h
new file mode 100644
index 0000000000000000000000000000000000000000..dab605a789f9da5892c999d3534fcdca88e043bb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MaxPooling.h
@@ -0,0 +1,102 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+// Validates the arguments of max_pool1d; the first violated condition raises
+// through TORCH_CHECK and nothing is returned.
+// Checks, in order: `self` is 2D or 3D; kernel_size, padding and dilation each
+// have exactly one element and stride has zero or one; kernel_size > 0;
+// stride > 0; padding >= 0; padding <= kernel_size / 2; dilation > 0; and the
+// output width computed by pooling_output_shape is > 0.
+// `ceil_mode` is only forwarded to pooling_output_shape.
+inline void check_max_pool1d(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+
+  TORCH_CHECK(
+      self.dim() == 2 || self.dim() == 3,
+      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
+  TORCH_CHECK(
+      kernel_size.size() == 1,
+      "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
+      kernel_size.size());
+  TORCH_CHECK(
+      stride.empty() || stride.size() == 1,
+      "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
+      stride.size());
+  TORCH_CHECK(
+      padding.size() == 1,
+      "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
+      padding.size());
+  TORCH_CHECK(
+      dilation.size() == 1,
+      "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
+      dilation.size());
+
+  // If stride=None then set it to kernel_size
+  if (stride.empty()) {
+    stride = kernel_size;
+  }
+
+  TORCH_CHECK(
+      kernel_size[0] > 0,
+      "max_pool1d() kernel_size must be greater than zero, but got ",
+      kernel_size[0]);
+  TORCH_CHECK(
+      stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
+  TORCH_CHECK(
+      padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
+  // Integer division: for an odd kernel_size the bound is rounded down.
+  TORCH_CHECK(
+      padding[0] <= kernel_size[0] / 2,
+      "max_pool1d() padding should be at most half of kernel size, but got padding=",
+      padding[0],
+      " and kernel_size=",
+      kernel_size[0]);
+  TORCH_CHECK(
+      dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
+
+  // The last input dimension is the pooled width; guard_int turns the symbolic size into a concrete int64_t.
+  const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
+  TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
+}
+
+// TODO(Heitor) Template by dimension
+// Plain aggregate describing one 1D pooling problem, plus helpers that map
+// (kernel index, output index) pairs to input indices.
+struct PoolingParams1D {
+  int64_t NB; // Number of batches
+  int64_t NC; // Number of channels
+  int64_t IW; // Input width
+  int64_t OW; // Output width
+  int64_t KW; // Kernel width
+  int64_t SJ; // Column stride
+  int64_t PJ; // Column padding
+  int64_t DJ; // Column dilation
+
+  // Input index read by kernel element `kj` when producing output `oj`.
+  // The result may fall outside [0, IW) where the window overlaps the padding.
+  inline int64_t index(int64_t kj, int64_t oj) const {
+    return oj * SJ + kj * DJ - PJ;
+  }
+
+  // First output index whose input index for kernel element `kj` is >= 0.
+  inline int64_t valid_output_start(int64_t kj) const {
+    const int64_t first_input = index(kj, 0);
+    if (first_input >= 0) {
+      return 0;
+    }
+    return at::divup(-first_input, SJ);
+  }
+
+  // One past the last output index whose input index for kernel element `kj`
+  // is < IW.
+  inline int64_t valid_output_end(int64_t kj) const {
+    const int64_t last_input = index(kj, OW - 1);
+    if (last_input < IW) {
+      return OW;
+    }
+    return OW - at::divup(last_input - (IW - 1), SJ);
+  }
+};
+
+using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
+
+DECLARE_DISPATCH(pooling_fn, max_pool1d_stub)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/NonEmptyUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/NonEmptyUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..8eff5eb7714571a6c66513b78e35d5dfd02aca90
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/NonEmptyUtils.h
@@ -0,0 +1,32 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+// Returns `dim` clamped from below to 1, so a 0-dim count is treated as 1.
+inline int64_t ensure_nonempty_dim(int64_t dim) {
+  // Explicit <int64_t>: the literal 1 is an int, so plain std::max cannot deduce.
+  return std::max<int64_t>(dim, 1);
+}
+
+// Size of `t` along `dim`, or 1 when `t` is 0-dimensional.
+inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
+  if (t.dim() == 0) {
+    return 1;
+  }
+  return t.size(dim);
+}
+
+// Stride of `t` along `dim`, or 1 when `t` is 0-dimensional.
+inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
+  if (t.dim() == 0) {
+    return 1;
+  }
+  return t.stride(dim);
+}
+
+// Vector of sizes/strides/indices.
+using IdxVec = std::vector<int64_t>;
+// Returns `vec` unchanged when it has elements, otherwise a vector holding a
+// single 1. The argument is taken by value, so the caller's vector is untouched.
+inline IdxVec ensure_nonempty_vec(IdxVec vec) {
+  if (vec.empty()) {
+    vec.push_back(1);
+  }
+  return vec;
+}
+
+}  // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Normalization.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Normalization.h
new file mode 100644
index 0000000000000000000000000000000000000000..34fcd64e33aea5d838a8d640fed0d54c0b36f6fa
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Normalization.h
@@ -0,0 +1,24 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+// Function-pointer type of the renorm scale-factor kernel: it receives the
+// tensor iterator to operate on and the `maxnorm` value.
+using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
+// Declares the dispatch stub through which that kernel is reached.
+DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub)
+
+// Batch-norm backends that _select_batch_norm_backend can report.
+enum class BatchNormBackend {
+  Native,
+  Cudnn,
+  Miopen,
+};
+
+// Declaration only: given the batch-norm arguments, returns a BatchNormBackend.
+// NOTE(review): the selection criteria live in the definition, which is not
+// part of this header.
+TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
+
+}  // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Padding.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Padding.h
new file mode 100644
index 0000000000000000000000000000000000000000..ba9f7825c55ffbda03ac375531459b135240f70b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Padding.h
@@ -0,0 +1,68 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
+
+// reflection padding
+DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel)
+DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel)
+DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel)
+DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel)
+DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel)
+DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel)
+
+// replication padding
+DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel)
+DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel)
+DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel)
+DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel)
+DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel)
+DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel)
+
+namespace padding {
+
+// Validates the arguments of a `dim`-dimensional padding op.
+//   - `padding` must hold exactly 2 * dim entries.
+//   - `input` must have either dim + 1 dimensions with every size non-zero,
+//     or dim + 2 dimensions with every size except dimension 0 non-zero.
+// Any violation fails through TORCH_CHECK.
+// NOTE(review): `dim` is not declared in the visible text; presumably it is
+// the parameter of the template header below -- confirm.
+template 
+inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
+
+  TORCH_CHECK(padding.size() == 2 * dim,
+      "padding size is expected to be ", 2 * dim,
+      ", but got: ", padding.size());
+
+  int input_dim = input.dim();
+
+  bool is_batch_mode = input_dim == (dim + 2);
+  bool is_non_batch_mode = input_dim == (dim + 1);
+
+  // Each flag starts from the rank test and is then narrowed by the size scan.
+  bool valid_batch_mode = is_batch_mode;
+  bool valid_non_batch_mode = is_non_batch_mode;
+
+  if (is_batch_mode) {
+    // Dimension 0 may be zero-sized in this mode, so the scan starts at 1.
+    for (const auto d : c10::irange(1, input_dim)) {
+      valid_batch_mode = valid_batch_mode && input.size(d) != 0;
+    }
+  } else {
+    // Every dimension, including dimension 0, must be non-zero here.
+    for (const auto d : c10::irange(0, input_dim)) {
+      valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
+    }
+  }
+
+  // allow empty batch size but not other dimensions.
+  TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
+      "Expected ", dim + 1, "D or ", dim + 2,
+      "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
+      input.sizes());
+}
+
+} // namespace padding
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/PixelShuffle.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/PixelShuffle.h
new file mode 100644
index 0000000000000000000000000000000000000000..856b015a35153bc45fb5320642d1c79fcf564f64
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/PixelShuffle.h
@@ -0,0 +1,53 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#include 
+#include 
+
+namespace at::native {
+
+// Validates the arguments of pixel_shuffle. Checks, in order:
+//   1. `self` has at least 3 dimensions;
+//   2. `upscale_factor` is positive;
+//   3. upscale_factor * upscale_factor does not overflow (TORCH_CHECK_VALUE);
+//   4. self.size(-3) is divisible by upscale_factor squared.
+inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
+  TORCH_CHECK(self.dim() >= 3,
+              "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
+              self.dim(), " dimension(s)");
+  TORCH_CHECK(upscale_factor > 0,
+              "pixel_shuffle expects a positive upscale_factor, but got ",
+              upscale_factor);
+  int64_t c = self.size(-3);
+  // Division-based guard: the product below is only formed once it is known
+  // to fit (upscale_factor is already known to be positive here).
+  TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits::max() / upscale_factor,
+        "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
+  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
+  TORCH_CHECK(c % upscale_factor_squared == 0,
+              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
+              "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
+}
+
+// Validates the arguments of pixel_unshuffle. Checks, in order:
+//   1. `self` has at least 3 dimensions;
+//   2. `downscale_factor` is positive;
+//   3. self.size(-2) is divisible by downscale_factor;
+//   4. self.size(-1) is divisible by downscale_factor.
+inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
+  TORCH_CHECK(
+      self.dim() >= 3,
+      "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
+      self.dim(),
+      " dimension(s)");
+  TORCH_CHECK(
+      downscale_factor > 0,
+      "pixel_unshuffle expects a positive downscale_factor, but got ",
+      downscale_factor);
+  // The messages below refer to size(-2) as height and size(-1) as width.
+  int64_t h = self.size(-2);
+  int64_t w = self.size(-1);
+  TORCH_CHECK(
+      h % downscale_factor == 0,
+      "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
+      h,
+      " is not divisible by ",
+      downscale_factor);
+  TORCH_CHECK(
+      w % downscale_factor == 0,
+      "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
+      w,
+      " is not divisible by ",
+      downscale_factor);
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Pool.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2d2a054df6c6b66146ef004596bf7749869365e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Pool.h
@@ -0,0 +1,366 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#pragma once
+
+namespace at::native {
+
+using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
+    int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
+using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
+
+DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel)
+DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel)
+
+// average pooling: forward and backward take the same argument list, but the forward kernel uses int64_t where the backward kernel uses int
+using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
+    int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional divisor_override);
+using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
+    int dW, int dH, int padW, int padH, bool count_include_pad, std::optional divisor_override);
+
+DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel)
+DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel)
+
+// average pooling: forward and backward take the same argument list, but the forward kernel uses int64_t where the backward kernel uses int
+using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input,
+    int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD,
+    int64_t padW, int64_t padH, int64_t padD, bool count_include_pad,
+    std::optional divisor_override);
+using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input,
+    int kW, int kH, int kD, int dW, int dH, int dD,
+    int padW, int padH, int padD, bool count_include_pad,
+    std::optional divisor_override);
+
+DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel)
+DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel)
+
+using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
+    int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
+using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
+
+DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel)
+DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel)
+namespace {
+
+// Range-checked integer conversion: TORCH_CHECKs that `v` lies between a
+// numeric_limits min() and max(), failing with "integer out of range", and
+// then returns `v` converted with static_cast.
+// NOTE(review): the type arguments of numeric_limits and static_cast are not
+// visible here; presumably both refer to dest_t -- confirm.
+template 
+inline dest_t
+safe_downcast(src_t v)
+{
+  TORCH_CHECK(std::numeric_limits::min() <= v && v <= std::numeric_limits::max(),
+              "integer out of range");
+
+  return static_cast(v);
+}
+
+// Computes the pooled output length along one dimension, with separate left
+// (`pad_l`) and right (`pad_r`) padding:
+//   div_rtn(inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1
+//           + (ceil_mode ? stride - 1 : 0), stride) + 1
+// In ceil mode the result is reduced by one when the last window would start
+// at or beyond inputSize + pad_l. No argument validation happens here.
+template
+inline T pooling_output_shape_pad_lr(
+        T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
+        bool ceil_mode) {
+    T outputSize = div_rtn(
+        inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
+        (ceil_mode ? stride - 1 : 0), stride) + 1;
+    if (ceil_mode) {
+        // ensure that the last pooling starts inside the image
+        // needed to avoid problems in ceil mode
+        if ((outputSize - 1) * stride >= inputSize + pad_l) {
+          --outputSize;
+        }
+    }
+    return outputSize;
+}
+
+// Symmetric-padding front end for pooling_output_shape_pad_lr: `pad` is used
+// as both the left and the right padding.
+// Before computing the size it TORCH_CHECKs that
+//   - stride is not zero,
+//   - pad is non-negative,
+//   - pad <= ((kernelSize - 1) * dilation + 1) / 2.
+template
+inline T pooling_output_shape(
+      T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
+    TORCH_CHECK(stride != 0, "stride should not be zero");
+    TORCH_CHECK(pad >= 0,
+                "pad must be non-negative, but got pad: ", pad);
+    TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
+                "pad should be at most half of effective kernel size, but got pad=",
+                pad, ", kernel_size=", kernelSize, " and dilation=", dilation)
+    return pooling_output_shape_pad_lr(
+        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
+}
+
+// Splits the padding for "same"-mode pooling into a {left, right} pair.
+// The total starts at dilation * (kernelSize - 1). It is reduced by one when
+// stride > 2, the total is odd, and inputSize % stride - 1 is positive.
+// `left` receives total / 2 and `right` receives whatever remains, so any odd
+// unit ends up on the right.
+template 
+std::pair _pooling_same_mode_padding_lr(
+    T inputSize, T kernelSize, T stride, T dilation) {
+  // NOTE: with strides, the output shape is ceil(inputSize/stride)
+  auto total_padding = T(dilation) * (kernelSize - 1);
+
+  // Prefer symmetric padding if possible
+  if (stride > 2 && (total_padding % 2 == 1)) {
+    // The floor in the output size calculation gives us a little wiggle room
+    auto wiggle_room = inputSize % stride - 1;
+    if (wiggle_room > 0) {
+      total_padding = total_padding - 1;
+    }
+  }
+
+  auto left = total_padding / 2;
+  return {left, total_padding - left};
+}
+
+// int64_t overload: forwards to _pooling_same_mode_padding_lr and returns its
+// {left, right} padding pair.
+inline std::pair pooling_same_mode_padding_lr(
+    int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
+  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
+}
+
+// c10::SymInt overload: same computation; the by-value arguments are moved
+// into the shared implementation.
+inline std::pair pooling_same_mode_padding_lr(
+    c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
+  return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
+}
+
+// AveragePool2d/DilatedMaxPool2d (forward)
+// Validates the arguments of a 2d pooling op. Checks, in order:
+//   1. kernel sizes, strides and dilations are all > 0;
+//   2. the input rank and sizes: ChannelsLast requires a 4D input whose
+//      dimensions 1..3 are non-zero; otherwise a 3D input with all sizes
+//      non-zero or a 4D input with dimensions 1..3 non-zero is accepted;
+//   3. padW <= kW/2 and padH <= kH/2;
+//   4. outputWidth and outputHeight are both >= 1.
+// Any violation fails through TORCH_CHECK.
+inline void
+pool2d_shape_check(
+  const Tensor& input,
+  int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, int64_t dilationH, int64_t dilationW,
+  int64_t nInputPlane,
+  int64_t inputHeight, int64_t inputWidth,
+  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
+{
+  const int64_t ndim = input.ndimension();
+  // nOutputPlane is only read by the last error message, so it is only
+  // declared when error messages are not stripped.
+#ifndef STRIP_ERROR_MESSAGES
+  const int64_t nOutputPlane = nInputPlane;
+#endif
+
+  TORCH_CHECK(kW > 0 && kH > 0,
+              "kernel size should be greater than zero, but got ",
+              "kH: ", kH, " kW: ", kW);
+  TORCH_CHECK(dW > 0 && dH > 0,
+              "stride should be greater than zero, but got "
+              "dH: ", dH, " dW: ", dW);
+  TORCH_CHECK(dilationH > 0 && dilationW > 0,
+              "dilation should be greater than zero, but got ",
+              "dilationH: ", dilationH, " dilationW: ", dilationW);
+
+  // Note: size(1) and size(2) are read here, before ndim has been checked.
+  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
+  if (memory_format == at::MemoryFormat::ChannelsLast){
+    // Expect tensor in NHWC format and allow 0-dim only for N.
+    TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
+      "Expected 4D (batch mode) tensor expected for input with channels_last layout"
+      " with optional 0 dim batch size for input, but got: ", input.sizes());
+  } else {
+    TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
+      (ndim == 4 && valid_dims && input.size(3) != 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
+      input.sizes());
+  }
+
+  TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
+              "pad should be smaller than or equal to half of kernel size, but got ",
+              "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
+
+  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
+              "Given input size: (",
+              nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
+              "Calculated output size: (",
+              nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
+              "Output size is too small");
+}
+
+// DilatedMaxPool2d (backward)
+// Backward shape check for 2d max pooling. Runs the forward check
+// (pool2d_shape_check) first, then requires the last three dimensions of
+// both `gradOutput` and `indices` to be (nInputPlane, outputHeight,
+// outputWidth). For a 4D input, dimension 0 of both tensors must also equal
+// input.size(0).
+inline void
+max_pool2d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  const Tensor& indices,
+  int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
+  int64_t nInputPlane,
+  int64_t inputHeight, int64_t inputWidth,
+  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
+{
+  pool2d_shape_check(
+    input,
+    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
+
+  const int64_t rank = input.ndimension();
+
+  // Pooling keeps the plane count, so the expected plane size is nInputPlane.
+  const auto check_trailing_dims = [&](const Tensor& t) {
+    check_dim_size(t, rank, rank - 3, nInputPlane);
+    check_dim_size(t, rank, rank - 2, outputHeight);
+    check_dim_size(t, rank, rank - 1, outputWidth);
+  };
+  check_trailing_dims(gradOutput);
+  check_trailing_dims(indices);
+
+  if (rank != 4) {
+    return;
+  }
+
+  // Batched input: the leading dimension has to agree as well.
+  const int64_t batch = input.size(0);
+  check_dim_size(gradOutput, rank, 0, batch);
+  check_dim_size(indices, rank, 0, batch);
+}
+
+// AveragePool2d (backward)
+// Backward shape check for 2d average pooling. Runs the forward check
+// (pool2d_shape_check) with both dilations fixed to 1, then requires the last
+// three dimensions of `gradOutput` to be (nInputPlane, outputHeight,
+// outputWidth). The batch-count argument is accepted but not used.
+inline void
+avg_pool2d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  int64_t /*nbatch*/,
+  int kH, int kW, int dH, int dW, int padH, int padW,
+  int64_t nInputPlane,
+  int64_t inputHeight, int64_t inputWidth,
+  int64_t outputHeight, int64_t outputWidth,
+  MemoryFormat memory_format)
+{
+  pool2d_shape_check(
+    input,
+    kH, kW, dH, dW, padH, padW, 1, 1,
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+    memory_format);
+
+  const int64_t rank = input.ndimension();
+
+  // Expected sizes of gradOutput's trailing dimensions, outermost first.
+  const int64_t expected[] = {nInputPlane, outputHeight, outputWidth};
+  for (int64_t k = 0; k < 3; ++k) {
+    check_dim_size(gradOutput, rank, rank - 3 + k, expected[k]);
+  }
+}
+
+// AveragePool3d/DilatedMaxPool3d (forward)
+// Validates the arguments of a 3d pooling op. Checks, in order:
+//   1. kernel sizes, strides and dilations are all > 0;
+//   2. `input` is 4D or 5D;
+//   3. every dimension of `input` has positive length, except that
+//      dimension 0 of a 5D input (the batch) may be zero;
+//   4. when `check_input_size` is set, the input extent is at least the
+//      kernel extent in T, H and W;
+//   5. each pad is at most half of the matching kernel size;
+//   6. otime, oheight and owidth are all >= 1.
+// `fn_name` prefixes the rank and dimension-length messages; `nslices` is
+// only used in the final message. Any violation fails through TORCH_CHECK.
+inline void
+pool3d_shape_check(
+  const Tensor& input,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int dilationT, int dilationH, int dilationW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char *fn_name,
+  bool check_input_size=false)
+{
+  const int64_t ndim = input.ndimension();
+
+  TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
+              "kernel size should be greater than zero, but got ",
+              "kT: ", kT, " kH: ", kH, " kW: ", kW);
+  TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
+              "stride should be greater than zero, but got ",
+              "dT: ", dT, " dH: ", dH, " dW: ", dW);
+  TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
+              "dilation should be greater than zero, but got ",
+              "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
+
+  TORCH_CHECK(ndim == 4 || ndim == 5,
+              fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
+
+  for (const auto i : c10::irange(ndim)) {
+    if (ndim == 5 && i == 0) {
+      // size of batch-dim can be 0.
+      continue;
+    }
+    // Report the offending dimension's index `i`; printing input.size(i)
+    // here would only repeat the zero length the message already states.
+    TORCH_CHECK(
+        input.size(i) > 0,
+        fn_name,
+        ": Expected input's non-batch dimensions to have positive length,"
+        " but input has a shape of ",
+        input.sizes(),
+        " and non-batch dimension ",
+        i,
+        " has length zero!")
+  }
+
+  if (check_input_size) { // AveragePool3d
+    TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
+                "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
+                "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
+  }
+
+  TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
+              "pad should be smaller than or equal to half of kernel size, but got "
+              "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
+
+  TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
+              "Given input size: (",
+              nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
+              "Calculated output size: (",
+              nslices, "x", otime, "x", oheight, "x", owidth, "). ",
+              "Output size is too small");
+}
+
+// Backward shape check for 3d max pooling. Runs the forward check
+// (pool3d_shape_check, without the input-size test), then requires the last
+// four dimensions of both `gradOutput` and `indices` to be
+// (nslices, otime, oheight, owidth).
+inline void
+max_pool3d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  const Tensor& indices,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int dilationT, int dilationH, int dilationW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char* fn_name)
+{
+  pool3d_shape_check(
+    input, nslices,
+    kT, kH, kW, dT, dH, dW, pT, pH, pW,
+    dilationT, dilationH, dilationW,
+    itime, iheight, iwidth,
+    otime, oheight, owidth, fn_name);
+
+  const int64_t rank = input.ndimension();
+
+  // Same trailing-dimension requirement for both tensors, gradOutput first.
+  const auto check_trailing_dims = [&](const Tensor& t) {
+    check_dim_size(t, rank, rank - 4, nslices);
+    check_dim_size(t, rank, rank - 3, otime);
+    check_dim_size(t, rank, rank - 2, oheight);
+    check_dim_size(t, rank, rank - 1, owidth);
+  };
+  check_trailing_dims(gradOutput);
+  check_trailing_dims(indices);
+}
+
+// Backward shape check for 3d average pooling. Runs the forward check
+// (pool3d_shape_check) with all dilations fixed to 1 and the input-size test
+// enabled, then requires the last four dimensions of `gradOutput` to be
+// (nslices, otime, oheight, owidth).
+inline void
+avg_pool3d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char *fn_name)
+{
+  pool3d_shape_check(
+    input, nslices,
+    kT, kH, kW, dT, dH, dW, pT, pH, pW,
+    1, 1, 1,
+    itime, iheight, iwidth,
+    otime, oheight, owidth,
+    fn_name, true);
+
+  const int64_t rank = input.ndimension();
+
+  // Expected sizes of gradOutput's trailing dimensions, outermost first.
+  const int64_t expected[] = {nslices, otime, oheight, owidth};
+  for (int64_t k = 0; k < 4; ++k) {
+    check_dim_size(gradOutput, rank, rank - 4 + k, expected[k]);
+  }
+}
+
+} // anonymous namespace
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RNN.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RNN.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8cb3f41ed2016e2755b3734677dd2543259b5a5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RNN.h
@@ -0,0 +1,58 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool);
+using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);
+using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool);
+using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool);
+
+DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub)
+DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub)
+DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub)
+DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub)
+DECLARE_DISPATCH(rnn_fn, gru_miopen_stub)
+DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub)
+DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub)
+DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub)
+DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub)
+DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub)
+DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub)
+DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub)
+DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub)
+DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub)
+DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub)
+DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub)
+DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub)
+
+// Checks that every defined tensor in `hiddens` and then in `params` is on
+// the same device as `input`; when `check_dtype` is true, also that it has
+// the same scalar type. Undefined tensors are skipped. The first mismatch
+// fails through TORCH_CHECK with a message naming the tensor kind
+// ("hidden" or "parameter").
+inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) {
+  auto input_device = input.device();
+  auto input_dtype = input.scalar_type();
+
+  // `name` is only used to label the error messages.
+  auto check_tensors = [&](const std::string& name, const Tensor& t) {
+    if (!t.defined()) return;
+    auto t_device = t.device();
+    TORCH_CHECK(input_device == t_device,
+             "Input and ", name, " tensors are not at the same device, found input tensor at ",
+             input_device, " and ", name, " tensor at ", t_device);
+    if (check_dtype) {
+      auto t_dtype = t.scalar_type();
+      TORCH_CHECK(input_dtype == t_dtype,
+               "Input and ", name, " tensors are not the same dtype, found input tensor with ",
+               input_dtype, " and ", name, " tensor with ", t_dtype);
+    }
+  };
+
+  for (const auto& h : hiddens) check_tensors("hidden", h);
+  for (const auto& p : params) check_tensors("parameter", p);
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceOpsUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..8dff19ab88c6cf2b203212a8087028af42756f58
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceOpsUtils.h
@@ -0,0 +1,473 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+// Maximum and minimum possible scalar values, including infinities
+// Largest value of the scalar type: infinity() when the type's numeric_limits
+// reports has_infinity, otherwise max().
+// NOTE(review): the numeric_limits type argument is not visible here;
+// presumably it is scalar_t -- confirm.
+template 
+constexpr scalar_t upper_bound() {
+  using lim = std::numeric_limits;
+  return lim::has_infinity ? lim::infinity() : lim::max();
+}
+
+// Smallest value of the scalar type: -infinity() when the type's
+// numeric_limits reports has_infinity, otherwise lowest().
+// NOTE(review): the numeric_limits type argument is not visible here;
+// presumably it is scalar_t -- confirm.
+template 
+constexpr scalar_t lower_bound() {
+  using lim = std::numeric_limits;
+  return lim::has_infinity ? -lim::infinity() : lim::lowest();
+}
+
+// Returns an as_strided view of `src` with shape `replacement_shape` and
+// src's strides, except that the stride at `dim` is set to 0. The stride
+// vector is passed through ensure_nonempty_vec before `dim` is indexed.
+inline Tensor restride_dim(
+  const Tensor& src, int64_t dim,
+  IntArrayRef replacement_shape
+) {
+  auto zeroed_strides = ensure_nonempty_vec(src.strides().vec());
+  zeroed_strides[dim] = 0;
+  return src.as_strided(replacement_shape, zeroed_strides);
+}
+
+// Resizes `result` in place to the shape of `self` with the size at `dim`
+// replaced by 1.
+// NOTE(review): `dim` indexes the size vector directly; presumably callers
+// pass an already-wrapped, in-range dim -- confirm.
+inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
+                                int64_t dim) {
+  IntArrayRef self_sizes = self.sizes();
+  std::vector result_sizes;
+  result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
+  result_sizes[dim] = 1;
+  result.resize_(result_sizes);
+}
+
+// Handles the two trivial inputs of a dimension-wise reduction and returns
+// true when `result` has been fully written:
+//   - zero-dimensional `self` with one element: `result` becomes a
+//     zero-dimensional tensor filled from `self`;
+//   - `self` with no elements: `result` takes self's shape with `dim` set to
+//     1, is filled with `ident`, and `dim` is squeezed unless `keepdim`.
+// Returns false, leaving `result` untouched, in every other case.
+inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
+                                      const Scalar& ident, int64_t dim, bool keepdim) {
+  const auto element_count = self.numel();
+
+  if (element_count == 1 && self.ndimension() == 0) {
+    result.resize_({});
+    result.fill_(self);
+    return true;
+  }
+
+  if (element_count != 0) {
+    return false;
+  }
+
+  // Empty input: every output element is the identity.
+  _dimreduce_setup(result, self, dim);
+  result.fill_(ident);
+  if (!keepdim) {
+    result.squeeze_(dim);
+  }
+  return true;
+}
+
+// Variant of the trivial-reduction shortcut for reductions without an
+// identity: only the zero-dimensional, single-element input is handled.
+// In that case `result` becomes a zero-dimensional tensor filled from `self`
+// and true is returned; otherwise nothing is written and false is returned.
+// The dim, keepdim and fn_name arguments are accepted but not used.
+inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
+                                               int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
+  const bool is_single_scalar = self.numel() == 1 && self.ndimension() == 0;
+  if (!is_single_scalar) {
+    return false;
+  }
+
+  result.resize_({});
+  result.fill_(self);
+  return true;
+}
+
+// Shortcut for a full reduction over an empty tensor: when `self` has no
+// elements, returns a scalar tensor holding `ident`, created with self's
+// options. Returns std::nullopt otherwise so the caller runs the real
+// reduction.
+inline std::optional _allreduce_return_trivial(
+    const Tensor& self,
+    const Scalar& ident) {
+  // Return identity
+  if (self.numel() == 0) {
+    return at::scalar_tensor(ident, self.options());
+  }
+  return std::nullopt;
+}
+
+// Expands to a braced TORCH_CHECK that out.option() == self.option(), where
+// `option` is the name of an accessor callable on both arguments. On failure
+// the message names the accessor and prints the expected (self) and found
+// (out) values.
+#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
+{ \
+  TORCH_CHECK(\
+    out.option() == self.option(),\
+    "expected ", #option, " ",\
+    self.option(),\
+    " but found ", out.option())\
+}
+
+// Checks that `out` matches `self` in scalar type, device and layout, in
+// that order; the first mismatch fails through TORCH_CHECK. Scalar type is
+// compared on the tensors themselves, device and layout on their options().
+inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
+  OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
+  OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
+  OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
+}
+
+// Returns `self` converted to the requested `dtype`. When `dtype` is empty,
+// integral inputs (bool included) are converted to Long and all other inputs
+// keep their scalar type. Fails through TORCH_CHECK when
+// isBarebonesUnsignedType is true for self's scalar type (the message names
+// uint16, uint32 and uint64).
+inline Tensor integer_upcast(const Tensor& self, std::optional dtype) {
+  ScalarType scalarType = self.scalar_type();
+  TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
+  ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
+  return self.toType(upcast_scalarType);
+}
+
+using DimMask = TensorIterator::DimMask;
+
+// Returns the dims to reduce over as a DimVector: a copy of `opt_dims` when
+// it has a value (an empty list stays empty), otherwise 0, 1, ..., ndim - 1.
+inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
+  if (opt_dims.has_value()) {
+    return DimVector(opt_dims.value());
+  } else {
+    std::vector all_dims(ndim);
+    std::iota(all_dims.begin(), all_dims.end(), 0);
+    return DimVector(all_dims);
+  }
+}
+
+// Builds the bit mask of dimensions to reduce over.
+//   - no `opt_dims` value: every bit is set;
+//   - empty list and `allow_empty_dims` false: every bit is set;
+//   - otherwise: the result of at::dim_list_to_bitset(dims, ndim).
+inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
+  if (!opt_dims.has_value()) {
+    return DimMask().flip();
+  }
+
+  const auto dims = opt_dims.value();
+  const bool treat_empty_as_all = dims.empty() && !allow_empty_dims;
+  if (treat_empty_as_all) {
+    return DimMask().flip();
+  }
+
+  return at::dim_list_to_bitset(dims, ndim);
+}
+
+// Computes the output shape of a reduction of `self` over the dimensions set
+// in `mask`: a reduced dimension becomes size 1 when `keepdim` is true and is
+// dropped otherwise; all other dimensions keep their size.
+inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
+  const auto input_sizes = self.sizes();
+  DimVector reduced_shape;
+  for (size_t d = 0; d < input_sizes.size(); ++d) {
+    if (!mask[d]) {
+      reduced_shape.push_back(input_sizes[d]);
+    } else if (keepdim) {
+      reduced_shape.push_back(1);
+    }
+    // A reduced dimension without keepdim contributes nothing.
+  }
+  return reduced_shape;
+}
+
+// Resizes `result` (via at::native::resize_output) to the shape produced by
+// reducing `self` over `mask` with the given `keepdim`. `result` must already
+// be a defined tensor; otherwise the TORCH_CHECK below fails. The dtype
+// argument is accepted but not used.
+inline void resize_reduction_result(
+    Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
+    ScalarType /*dtype*/)
+{
+  auto shape = shape_from_dim_mask(self, mask, keepdim);
+  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
+  at::native::resize_output(result, shape);
+}
+
+// Allocates (at::empty) the output of reducing `self` over `dim` with the
+// given `keepdim`. The tensor uses self's options with the dtype replaced by
+// `dtype`. The mask is built with make_dim_mask's default handling of empty
+// dim lists.
+inline Tensor create_reduction_result(
+  const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
+) {
+  const auto out_shape =
+      shape_from_dim_mask(self, make_dim_mask(dim, self.dim()), keepdim);
+  const auto out_options = self.options().dtype(dtype);
+  return at::empty(out_shape, out_options);
+}
+
+// Returns `result` viewed with the reduced dimensions put back.
+// With `keepdim` the tensor is returned unchanged. Otherwise, for each
+// dimension below `ndim` whose bit is set in `mask`, a size-1 / stride-0
+// entry is inserted at that position, and an as_strided view with the
+// resulting shape and strides is returned.
+// The inserts run in increasing dimension order, so each position already
+// accounts for the entries inserted before it.
+inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
+  if (keepdim) {
+    return result;
+  }
+  auto shape = DimVector(result.sizes());
+  auto stride = DimVector(result.strides());
+  for (const auto dim : c10::irange(ndim)) {
+    if (mask[dim]) {
+      shape.insert(shape.begin() + dim, 1);
+      stride.insert(stride.begin() + dim, 0);
+    }
+  }
+  return result.as_strided(shape, stride);
+}
+
+// Builds the TensorIterator for a single-output reduction of `self` into
+// `result`, with explicit input (`in_dtype`) and output (`out_dtype`) types.
+// Steps:
+//   1. TORCH_CHECK that a defined `result` already has scalar type out_dtype
+//      (`name` prefixes the message);
+//   2. build the dim mask; an absent or empty dim list selects every
+//      dimension;
+//   3. resize `result` to the reduced shape (resize_reduction_result also
+//      requires `result` to be defined);
+//   4. view `result` with the reduced dimensions restored;
+//   5. propagate names from `self` to `result`;
+//   6. create the reduce_op iterator, converting `self` with self.to(in_dtype)
+//      when its scalar type differs from in_dtype.
+inline TensorIterator make_reduction(
+    const char* name, Tensor& result, const Tensor& self,
+    at::OptionalIntArrayRef dim_opt,
+    bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
+  // check that result type and dtype match if provided
+  TORCH_CHECK(
+      !result.defined() || result.scalar_type() == out_dtype,
+      name, ": provided dtype must match dtype of result. Got ",
+      toString(result.scalar_type()),
+      " and ",
+      toString(out_dtype),
+      ".");
+  // dim={} performs an all-reduce, same as dim=None
+  IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
+  int64_t ndim = self.dim();
+  auto mask = make_dim_mask(dim, ndim);
+  resize_reduction_result(result, self, mask, keepdim, out_dtype);
+  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
+  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
+  if (self.scalar_type() == in_dtype) {
+    return TensorIterator::reduce_op(viewed_result, self);
+  }
+  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
+}
+
+// Single-output overload that derives the input dtype from `out_dtype` and
+// forwards to the overload taking both in_dtype and out_dtype.
+// The input dtype is chosen as:
+//   - self's own scalar type, when `self` is on CUDA or XPU, is Half or
+//     BFloat16, and out_dtype is Float (so the input is not converted);
+//   - c10::toComplexType(out_dtype), when `self` is complex;
+//   - out_dtype otherwise.
+[[maybe_unused]] inline TensorIterator make_reduction(
+    const char* name,
+    Tensor& result,
+    const Tensor& self,
+    at::OptionalIntArrayRef dim,
+    bool keepdim,
+    ScalarType out_dtype) {
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // not generalize this to common mismatched input/output types to avoid cross
+  // product of templated kernel launches.
+  const bool gpu_lowp_to_f32 = (
+        (self.is_cuda() || self.is_xpu()) && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
+  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
+                   : self.is_complex() ? c10::toComplexType(out_dtype)
+                                       : out_dtype;
+  return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
+}
+
+// Builds the TensorIterator for a reduction of `self` with two outputs,
+// `result1` (dtype1) and `result2` (dtype2).
+// Steps:
+//   1. TORCH_CHECK that each defined result already has its expected dtype;
+//   2. build the dim mask; an absent or empty dim list selects every
+//      dimension;
+//   3. resize both results to the reduced shape and view each with the
+//      reduced dimensions restored;
+//   4. propagate names from `self` to both results;
+//   5. create the reduce_op iterator. `self` is used unconverted when its
+//      scalar type equals dtype1, or when it is a CUDA Half tensor and dtype1
+//      is Float; otherwise it is converted with self.to(dtype1).
+// NOTE(review): unlike the single-output overload, the mixed-precision
+// shortcut here covers neither XPU nor BFloat16 -- confirm this is intended.
+inline TensorIterator make_reduction(
+    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
+    at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
+    ScalarType dtype2) {
+  // check that result type and dtype match if provided
+  TORCH_CHECK(
+    (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
+    name, ": provided dtype must match dtype of result. Got ",
+    toString(result1.scalar_type()), toString(result2.scalar_type()),
+    " and ",
+    toString(dtype1), toString(dtype2),
+    ".");
+
+  // dim={} performs an all-reduce, same as dim=None
+  auto dim = dim_opt.value_or(IntArrayRef{});
+  int64_t ndim = self.dim();
+  DimMask mask = make_dim_mask(dim, ndim);
+  resize_reduction_result(result1, self, mask, keepdim, dtype1);
+  auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
+
+  resize_reduction_result(result2, self, mask, keepdim, dtype2);
+  auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
+
+  namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
+  namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);
+
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid cross
+  // product of templated kernel launches.
+  if (self.scalar_type() == dtype1 ||
+      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
+    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
+  }
+  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
+}
+
+// Two-output convenience overload: forwards to the two-dtype overload with
+// `dtype` used for both results.
+[[maybe_unused]] inline TensorIterator make_reduction(
+    const char* name,
+    Tensor& result1,
+    Tensor& result2,
+    const Tensor& self,
+    at::OptionalIntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype) {
+  return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
+}
+
+// Validates a single reduction dimension `dim` of `self`.
+// For a 0-dim tensor only dim 0 or -1 is accepted; otherwise the size of
+// `self` along `dim` must be non-zero. Failures are raised through
+// TORCH_CHECK_INDEX with `fn_name` as the message prefix.
+inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
+  // Scalar input: handled first so the general case needs no else-branch.
+  if (self.ndimension() == 0) {
+    TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
+      ": Expected reduction dim -1 or 0 for scalar but got ", dim);
+    return;
+  }
+
+  TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
+    ": Expected reduction dim ", dim, " to have non-zero size.");
+}
+
+// Validates a list of reduction dimensions.
+// The list must be non-empty (TORCH_CHECK otherwise); every entry is then
+// checked with the single-dimension overload of zero_numel_check_dims.
+inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
+  TORCH_CHECK(
+    !dim.empty(),
+      fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
+        "Specify the reduction dim with the 'dim' argument.");
+
+  // Check the entries one by one, in list order.
+  for (size_t idx = 0; idx < dim.size(); ++idx) {
+    zero_numel_check_dims(self, dim[idx], fn_name);
+  }
+}
+
+// Computes the output sizes for reducing the zero-element tensor `self` over
+// the single dimension `dim`.
+// Asserts self.numel() == 0 and validates `dim` with zero_numel_check_dims.
+// With keepdim the result is self's sizes with entry `dim` set to 1; without
+// keepdim it is self's sizes with entry `dim` left out.
+// NOTE(review): `dim` is used directly as a vector index and compared against
+// loop indices, so this presumably expects an already wrapped (non-negative)
+// dim - confirm against callers.
+inline std::vector get_zero_numel_tensor_size(
+    const Tensor& self,
+    const int64_t dim,
+    const bool keepdim,
+    const char* fn_name) {
+  TORCH_INTERNAL_ASSERT(self.numel() == 0,  fn_name, ": Expected self.numel() == 0.");
+  zero_numel_check_dims(self, dim, fn_name);
+  std::vector sizes;
+  if (keepdim) {
+    // Keep the rank; the reduced dimension collapses to size 1.
+    sizes = self.sizes().vec();
+    sizes[dim] = 1;
+  }
+  else {
+    // Drop the reduced dimension, keep all others in order.
+    for (const auto d : c10::irange(self.dim())) {
+      if (d != dim) {
+        sizes.push_back(self.sizes()[d]);
+      }
+    }
+  }
+  return sizes;
+}
+
+// Resize the result tensor and indices when result.numel() == 0 depending on values of
+// dim and keepdim for returning tensors containing reduction results.
+// This function should be called when you are reducing a zero-numel tensor and want to
+// resize the output and return it. This function exists for resizing zero-numel
+// tensors when the size of the reduction dimension is non-zero.
+//
+// Both `result` and `result_indices` are resized (via resize_output) to the
+// sizes returned by get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
+// that helper also performs the validation of `self` and `dim`.
+[[maybe_unused]] inline void zero_numel_tensor_resize(
+    Tensor& result,
+    Tensor& result_indices,
+    const Tensor& self,
+    const int64_t dim,
+    const bool keepdim,
+    const char* fn_name) {
+  auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
+  at::native::resize_output(result, sizes);
+  at::native::resize_output(result_indices, sizes);
+}
+
+// Chooses a dtype based on the input tensor.
+// Returns `dtype` when it holds a value. Otherwise returns kLong if
+// `promote_integers` is set and self's dtype is integral (bool included),
+// and self's own dtype in every other case.
+inline ScalarType get_dtype_from_self(
+    const Tensor& self,
+    const std::optional& dtype,
+    bool promote_integers) {
+  if (dtype.has_value()) {
+    return dtype.value();
+  }
+  ScalarType src_type = self.scalar_type();
+  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
+    return kLong;
+  }
+  return src_type;
+}
+
+// Chooses a dtype based on an out tensor.
+// Requires `result` to be defined (TORCH_CHECK otherwise). Returns `dtype`
+// when it holds a value, otherwise result's current dtype.
+inline ScalarType get_dtype_from_result(Tensor& result, std::optional dtype) {
+  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
+  if (dtype.has_value()) {
+    return dtype.value();
+  } else {
+    return result.scalar_type();
+  }
+}
+
+
+} // namespace at::native
+
+namespace at::meta {
+
+// Returns the output shape of reducing `self` over `dims`.
+// A dim mask is built from `dims` (with `allow_empty_dims` forwarded to
+// native::make_dim_mask) and converted to a shape, honoring `keepdim`, by
+// native::shape_from_dim_mask.
+[[maybe_unused]] inline DimVector get_reduction_shape(
+    const Tensor& self,
+    IntArrayRef dims,
+    bool keepdim,
+    bool allow_empty_dims = false) {
+  const int64_t rank = self.dim();
+  auto reduce_mask = native::make_dim_mask(dims, rank, allow_empty_dims);
+  return native::shape_from_dim_mask(self, reduce_mask, keepdim);
+}
+
+// Meta-function helper: declares output 0 of a reduction of `self` over
+// `opt_dims` with dtype `out_dtype`, then propagates names to it.
+//
+// The dims are turned into a DimVector (make_dim_vector) and wrapped against
+// self.dim() before the reduced shape is computed. `allow_empty_dims` is
+// forwarded to get_reduction_shape.
+inline void resize_reduction(
+    impl::MetaBase& meta,
+    const Tensor& self,
+    OptionalIntArrayRef opt_dims,
+    bool keepdim,
+    ScalarType out_dtype,
+    bool allow_empty_dims=false) {
+  DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
+  maybe_wrap_dims(dims_, self.dim());
+  auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
+  if (self.layout() == kStrided) {
+    meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
+  } else if (shape.empty()) {
+    // Non-strided input reduced to an empty shape: the output is declared
+    // with strided layout instead of self's layout.
+    meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided));
+  } else {
+    // Any other non-strided output is rejected.
+    TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet");
+  }
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(), self, dims_, keepdim);
+}
+
+// Meta-function helper for reductions that return values and indices.
+// Declares output 0 with dtype `out_dtype` and output 1 with dtype kLong,
+// both with the reduced shape of `self` over `dims` (wrapped against
+// self.dim()), and propagates reduction names to each output.
+inline void resize_reduction_with_indices(
+    impl::MetaBase& meta,
+    const Tensor& self,
+    IntArrayRef dims,
+    bool keepdim,
+    ScalarType out_dtype) {
+  DimVector dims_(dims);
+  maybe_wrap_dims(dims_, self.dim());
+  auto shape = get_reduction_shape(self, dims_, keepdim);
+  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
+  meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(0), self, dims_, keepdim);
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(1), self, dims_, keepdim);
+}
+
+// Builds the TensorIterator for a single-output reduction of `self` into
+// `result`. This overload does not resize `result`; it only views it through
+// review_reduce_result using the dim mask built from `opt_dims`.
+// `self` is converted to `in_dtype` first when its dtype differs.
+inline TensorIterator make_reduction(
+    const Tensor& self,
+    const Tensor& result,
+    OptionalIntArrayRef opt_dims,
+    bool keepdim,
+    ScalarType in_dtype) {
+  int64_t ndim = self.dim();
+  auto mask = at::native::make_dim_mask(opt_dims, ndim);
+  auto viewed_result =
+      at::native::review_reduce_result(result, ndim, mask, keepdim);
+  if (self.scalar_type() == in_dtype) {
+    return TensorIterator::reduce_op(viewed_result, self);
+  }
+  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
+}
+
+// Builds the TensorIterator for a two-output reduction of `self` into
+// `result1` and `result2`. Neither result is resized here; both are only
+// viewed through review_reduce_result.
+// The last ScalarType parameter (dtype2) is accepted but unused.
+// `self` is converted to dtype1 unless it already has that dtype or is a
+// CUDA Half tensor reduced into Float.
+inline TensorIterator make_reduction(
+    const Tensor& self,
+    const Tensor& result1,
+    const Tensor& result2,
+    IntArrayRef dims,
+    bool keepdim,
+    ScalarType dtype1,
+    ScalarType /*dtype2*/) {
+  int64_t ndim = self.dim();
+  auto mask = at::native::make_dim_mask(dims, ndim);
+  auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
+  auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
+  // special case for type promotion in mixed precision, improves computational efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid cross product
+  // of templated kernel launches.
+  if (self.scalar_type() == dtype1 ||
+      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
+    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
+  }
+  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
+}
+
+// Builds a single-output reduction iterator, deriving the input dtype from
+// the requested output dtype: the input is iterated as `out_dtype`, except
+// for CUDA Half/BFloat16 inputs reduced into Float, which keep their own
+// dtype. The actual construction is delegated to make_reduction.
+[[maybe_unused]] inline TensorIterator make_reduction_from_out_ty(
+    const Tensor& self,
+    const Tensor& result,
+    OptionalIntArrayRef opt_dims,
+    bool keepdim,
+    ScalarType out_dtype) {
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid
+  // cross product of templated kernel launches.
+  const bool gpu_lowp_to_f32 =
+      (self.is_cuda() &&
+       (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
+       out_dtype == kFloat);
+  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
+  return make_reduction(self, result, opt_dims, keepdim, in_dtype);
+}
+
+} // namespace at::meta
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Repeat.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Repeat.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f4f80a76d0a1b93be14296a15ec63b7ce2d5a91
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Repeat.h
@@ -0,0 +1,53 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+// Shared implementation of repeat_interleave for a `repeats` tensor.
+//
+// Template parameters:
+//   index_t - element type used for the `repeats` and result data pointers.
+//   compute - kernel called as compute(repeat_ptr, cumsum_ptr, result_ptr,
+//             repeats.size(0), total) to fill the result.
+//
+// `repeats` must be 1-D and of dtype Long or Int (TORCH_CHECK). An empty
+// `repeats` returns an empty tensor like it. The result has `total` elements
+// and repeats' options, where `total` is `output_size` when given and
+// otherwise the last element of cumsum(repeats).
+template <
+    typename index_t,
+    void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
+static inline Tensor repeat_interleave_common(
+    const Tensor& repeats,
+    std::optional output_size) {
+  TORCH_CHECK(
+      repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
+  TORCH_CHECK(
+      repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
+      "repeats has to be Long or Int tensor");
+  if (repeats.size(0) == 0) {
+    return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  Tensor repeats_ = repeats.contiguous();
+  Tensor cumsum = repeats.cumsum(0);
+  int64_t total = 0;
+  if (output_size.has_value()) {
+    total = output_size.value();
+  } else {
+    total = cumsum[-1].item();
+    // The non-negativity check only runs on this branch, i.e. when no
+    // output_size was supplied.
+    TORCH_CHECK(
+        (repeats >= 0).all().item(), "repeats can not be negative");
+  }
+
+  Tensor result = at::empty({total}, repeats.options());
+  const index_t* repeat_ptr = repeats_.const_data_ptr();
+  const int64_t* cumsum_ptr = cumsum.const_data_ptr();
+  index_t* result_ptr = result.data_ptr();
+  compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
+  return result;
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Resize.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Resize.h
new file mode 100644
index 0000000000000000000000000000000000000000..d3af99da67372aec3597b5dda7c5c0df14545884
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Resize.h
@@ -0,0 +1,210 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+
+
+namespace at::native {
+
+// TODO: make all operations that resize given outputs use this function
+//   for consistency and maintainability.
+//   Some operations like `cat` might not be able to make the use of
+//   resize_output directly. For more details to understand how it works in `cat`,
+//   see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
+// Resizes outputs
+// Functions accepting output tensors, like with the "out" kwarg, should
+//   call this function to handle resizing their output tensor.
+// Issues a warning if the output tensor has one or more elements and
+//   needs resizing
+// NOTE: In the future the warning will become an error
+// Returns a bool saying whether or not the resize actually happened
+TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
+// WARNING: Do NOT call this directly. If you are resizing an output and want
+// to support dynamic shapes call at::resize__symint and resize_output_check_symint.
+// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
+TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
+
+// Utility for resize_output
+//  Returns a bool saying resize should happen or not and
+//  raises a warning if resizing for one or more elements
+TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
+TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
+
+TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
+TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
+TORCH_API void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& size_bytes);
+
+// Makes sure the storage behind `self` can hold `new_size_bytes` bytes.
+// Does nothing for tensors with zero elements. If the tensor has no storage,
+// a new one of that size is created with the CPU allocator and installed
+// (dtype kept). An existing storage is only ever grown, never shrunk.
+inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
+  // It does not make sense to try to resize a storage
+  // to hold 0 elements, and this can break
+  // if storage_offset is positive but
+  // new_size is 0, so just bail in that case
+  // (same comment is in cuda/Resize.h)
+  if (self->numel() == 0) {
+    return;
+  }
+
+  const Storage& storage = self->unsafe_storage();
+  if (!storage) {
+    auto new_storage = c10::make_intrusive(
+        StorageImpl::use_byte_size_t(),
+        new_size_bytes,
+        c10::GetCPUAllocator(),
+        true);
+    self->set_storage_keep_dtype(std::move(new_storage));
+  } else if (new_size_bytes > storage.nbytes()) {
+    resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
+  }
+}
+
+TORCH_API TensorImpl* resize_impl_cpu_(
+    TensorImpl* self,
+    IntArrayRef size,
+    at::OptionalIntArrayRef stride,
+    bool resize_storage = true);
+
+// Converts a c10::SymInt to the requested type.
+// The primary template is deleted, so only the two specializations below are
+// usable: the SymInt one returns its argument unchanged, the int64_t one
+// calls guard_int(__FILE__, __LINE__) on it.
+template 
+T maybe_convert_symint(c10::SymInt) = delete;
+
+template <>
+inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
+
+template <>
+inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
+
+// Checks that a tensor described by (size, stride, storage_offset) with
+// element type `data_type` fits inside `new_storage`.
+// The required byte count is computed with computeStorageNbytes when strides
+// are given (stride.data() non-null) and with computeStorageNbytesContiguous
+// otherwise, both without and with the storage offset. A requirement of zero
+// bytes (offset excluded) always passes; otherwise the requirement including
+// the offset must not exceed the storage's byte size, else the check fails
+// with a "setStorage: ... out of bounds" message.
+template 
+inline void checkInBoundsForStorage(
+    ArrayRef size,
+    ArrayRef stride,
+    T storage_offset,
+    const caffe2::TypeMeta& data_type,
+    const Storage& new_storage) {
+  T storage_size_bytes, storage_size_plus_offset_bytes;
+  if (stride.data()) {
+    storage_size_bytes =
+        at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
+    storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
+        size, stride, data_type.itemsize(), storage_offset);
+  } else {
+    storage_size_bytes =
+        at::detail::computeStorageNbytesContiguous(size, data_type.itemsize());
+    storage_size_plus_offset_bytes = at::detail::computeStorageNbytesContiguous(
+        size, data_type.itemsize(), storage_offset);
+  }
+  // It's ok to always evaluate to False for this early return for SymInts because
+  // (1) maybe_convert_symint below only installs guard for int64_t case
+  // (2) we check for this condition in the TORCH_MAYBE_SYM_CHECK below
+  if (TORCH_GUARD_OR_FALSE(sym_eq(storage_size_bytes, 0))) {
+    // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
+    return;
+  }
+  T new_storage_size_bytes = maybe_convert_symint(new_storage.sym_nbytes());
+  TORCH_MAYBE_SYM_CHECK(
+      sym_eq(storage_size_bytes, 0) || sym_le(storage_size_plus_offset_bytes, new_storage_size_bytes),
+      "setStorage: sizes ",
+      size,
+      ", strides ",
+      stride,
+      ","
+      " storage offset ",
+      storage_offset,
+      ", and itemsize ",
+      data_type.itemsize(),
+      " requiring a storage size of ",
+      storage_size_plus_offset_bytes,
+      " are out of bounds for storage of size ",
+      new_storage_size_bytes);
+}
+
+// Validates the arguments of a set-storage operation on `result` and, when
+// `storage` is not an alias of result's current storage, installs it.
+//
+// Checks performed, in order:
+//   - size and stride have equal length (only when strides are given);
+//   - storage_offset is non-negative;
+//   - if check_offset_in_bounds and the sizes/strides are unchanged, the
+//     (size, stride, storage_offset) triple fits into `storage`;
+//   - when the storage is replaced: both storages are non-null and live on
+//     the same device.
+// This function does not set sizes, strides or the storage offset itself.
+template 
+inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
+                                   ArrayRef size, ArrayRef stride, bool check_offset_in_bounds = true) {
+  // FIXME: stride should be optional
+  if (stride.data()) {
+    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
+                                              ") and stride length (", stride.size(), ")");
+  }
+
+#ifdef DEBUG
+  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
+#endif
+
+  // storageOffset
+  TORCH_CHECK(
+    TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset);
+
+  // set_storage_{device} (except set_storage_meta__symint)
+  // will (unsafely) set the storage offset and then call resize_impl that
+  // handles resizing the storage. However, resize_impl will only resize the
+  // storage if the sizes/strides changed. For the case that the sizes/strides
+  // remain unchanged, the storage offset is not properly validated, so we do
+  // that here.
+  if (check_offset_in_bounds) {
+    auto result_tensor_impl = result.unsafeGetTensorImpl();
+    bool size_unchanged = result_tensor_impl->generic_sizes() == size;
+    bool stride_unchanged = stride.data()
+        ? result_tensor_impl->generic_strides() == stride
+        : true;
+    if (size_unchanged && stride_unchanged) {
+      checkInBoundsForStorage(
+          size, stride, storage_offset, result.dtype(), storage);
+    }
+  }
+
+  // storage: note this can't be replaced with result.set_(storage) as the semantics of that
+  // function is to set the tensor size to be equal to the size of the storage.
+  if (!result.storage().is_alias_of(storage)) {
+    // Caffe2 might have tensors whose storages are null, but we
+    // don't allow it in PyTorch.
+    TORCH_INTERNAL_ASSERT(storage);
+    TORCH_INTERNAL_ASSERT(result.storage());
+
+    // We used to allow this, but this breaks device caching.
+    // Let's put an actual error message for this one.
+    TORCH_CHECK(result.storage().device() == storage.device(),
+                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
+                "\" to a storage on different device \"", storage.device(),
+                "\".  This is no longer allowed; the devices must match.");
+    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
+  }
+}
+
+/**
+ * Set self's sizes, strides, and storage_offset.
+ * (size, stride, storage_offset) must be in bounds for self's storage.
+ *
+ * Checks, in order: size and stride have equal length; every stride is
+ * non-negative; the triple fits into self's current storage
+ * (checkInBoundsForStorage); storage_offset is non-negative. Only then are
+ * the new sizes, strides and offset written to the TensorImpl.
+ */
+template 
+inline void setStrided(
+    const Tensor& self,
+    ArrayRef size,
+    ArrayRef stride,
+    T storage_offset) {
+  TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
+  for (const auto& val : stride) {
+    TORCH_CHECK(val >= 0,
+                "as_strided: Negative strides are not supported at the moment, "
+                "got strides: ", stride);
+  }
+
+  auto* self_ = self.unsafeGetTensorImpl();
+  checkInBoundsForStorage(
+      size, stride, storage_offset, self_->dtype(), self_->storage());
+
+  /* storage offset */
+  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
+  self_->set_sizes_and_strides(size, stride, storage_offset);
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ScatterGatherChecks.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ScatterGatherChecks.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d7e4c319a4e6500947c85143fd20fe66e7380ec
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ScatterGatherChecks.h
@@ -0,0 +1,133 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+namespace {
+
+// checks whether index.dtype is int32 or int64 (skipped when index has no
+// elements) and self.dtype == src.dtype if src is a Tensor
+// `method_name` is used as the prefix of the error messages.
+inline void scatter_gather_dtype_check(
+  const std::string& method_name,
+  const Tensor& self,
+  const Tensor& index,
+  const std::optional& src_opt = std::nullopt
+) {
+  if (index.numel() != 0) {
+    TORCH_CHECK(
+      index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
+      method_name, "(): Expected dtype int32/int64 for index"
+    );
+  }
+
+  if (src_opt.has_value()) {
+    const auto& src = src_opt.value();
+    TORCH_CHECK(
+      self.scalar_type() == src.scalar_type(),
+      method_name, "(): Expected self.dtype to be equal to src.dtype"
+    );
+  }
+}
+
+// Shape validation for `gather`-like methods.
+// Note: `self` is the input tensor here.
+// Checks:
+//   1. index.dim() == self.dim()
+//   2. index.size(d) <= self.size(d) for all d != dim
+// Dimension counts and sizes are read through ensure_nonempty_dim /
+// ensure_nonempty_size.
+inline void gather_shape_check(const Tensor& self, int64_t dim,
+  const Tensor& index
+) {
+  auto self_dims = ensure_nonempty_dim(self.dim());
+  TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
+    "Index tensor must have the same number of dimensions as input tensor"
+  );
+
+  for (const auto i : c10::irange(self_dims)) {
+    // The gather dimension itself is exempt from the size comparison.
+    if (i == dim) {
+      continue;
+    }
+    TORCH_CHECK(
+      ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
+      "Size does not match at dimension ", i,
+      " expected index ", index.sizes(),
+      " to be no larger than self ", self.sizes(),
+      " apart from dimension ", dim
+    );
+  }
+}
+
+// Used for `scatter` and `scatter_add`
+// Tests:
+//  1. index.size(d) <= self.size(d) for all d != dim
+//  2. index.size(d) <= src.size(d) for all d if src is a Tensor
+//  3. index.dim() == self.dim() == src.dim()
+// All checks are skipped when index has no elements.
+// The size violations are first recorded in a flag and only reported at the
+// end, so that the dimension-count checks are reported before them.
+inline void scatter_shape_check(
+  const Tensor& self, int64_t dim, const Tensor& index,
+  const std::optional& src_opt = std::nullopt
+) {
+  if (index.numel() == 0) return;
+  TORCH_CHECK(
+    ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
+    "Index tensor must have the same number of dimensions as self tensor"
+  );
+
+  bool is_wrong_shape = false;
+  int64_t self_dims = ensure_nonempty_dim(self.dim());
+
+  //  Check: index.size(d) <= self.size(d) for all d != dim
+  for (const auto d : c10::irange(self_dims)) {
+    int64_t index_d_size = ensure_nonempty_size(index, d);
+    if (d == dim) continue;
+    if (index_d_size > ensure_nonempty_size(self, d)) {
+      is_wrong_shape = true;
+      break;
+    }
+  }
+
+  //  Check: index.size(d) <= src.size(d) for all d if src is Tensor
+  if (!is_wrong_shape && src_opt.has_value()) {
+    const auto& src = src_opt.value();
+    for (const auto d : c10::irange(self_dims)) {
+      int64_t index_d_size = ensure_nonempty_size(index, d);
+      if (index_d_size > ensure_nonempty_size(src, d)) {
+        is_wrong_shape = true;
+        break;
+      }
+    }
+  }
+
+  // Report: with src, check its dimension count first, then the recorded
+  // size violation (message mentions src); without src, only the latter.
+  if (src_opt.has_value()) {
+    const auto& src = src_opt.value();
+
+    TORCH_CHECK(
+      ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
+      "Index tensor must have the same number of dimensions as src tensor"
+    );
+
+    TORCH_CHECK(!is_wrong_shape,
+      "Expected index ", index.sizes(),
+      " to be no larger than self ", self.sizes(),
+      " apart from dimension ", dim,
+      " and to be no larger size than src ", src.sizes()
+    );
+  }
+  else {
+    TORCH_CHECK(!is_wrong_shape,
+      "Expected index ", index.sizes(),
+      " to be no larger than self ", self.sizes(),
+      " apart from dimension ", dim
+    );
+  }
+}
+
+} // anonymous namespace
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Sorting.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Sorting.h
new file mode 100644
index 0000000000000000000000000000000000000000..753430204919b2ea67c63f74fe2efdc6dc859538
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Sorting.h
@@ -0,0 +1,33 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+// Interpolation mode selector, stored as a uint8_t.
+// NOTE(review): presumably consumed by the quantile operators; their
+// handling of each mode is not visible in this header - confirm there.
+enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
+  LINEAR,
+  LOWER,
+  HIGHER,
+  MIDPOINT,
+  NEAREST
+};
+
+using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
+using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);
+
+DECLARE_DISPATCH(sort_fn, sort_stub)
+DECLARE_DISPATCH(topk_fn, topk_stub)
+
+void _fill_indices(const TensorBase &indices, int64_t dim);
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SortingUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SortingUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..badeba5f6fbffd30e06eb09be18f2e3d16d8e6d5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SortingUtils.h
@@ -0,0 +1,93 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+
+// ensure we get good values and indices for kthvalue, mode
+// this will always be with the reducing dim as 1-d
+//
+// The target sizes are self's sizes with the (wrapped) reduction dim set to
+// 1; a 0-dim self keeps empty sizes. For each of `values` / `indices`:
+//   - undefined: a fresh tensor of the target sizes is allocated (values
+//     with self's options, indices with dtype kLong);
+//   - defined: it is validated (values: same type as self; indices: dtype
+//     Long and same device as self) and resized with resize_output.
+// `keepdim` only affects the pre-resize unsqueeze of an already defined
+// output that has one dimension fewer than self.
+inline void _reduction_with_indices_allocate_or_resize_output(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t dim_,
+    bool keepdim) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  auto result_sizes = self.sizes().vec();
+  if (!result_sizes.empty()) {
+    result_sizes[dim] = 1;
+  }
+  if (values.defined()) {
+    TORCH_CHECK(
+        self.options().type_equal(values.options()),
+        "output values must be of same type as input");
+    if (!keepdim && values.dim() == self.dim() - 1) {
+      // unsqueeze to preserve passed in noncontiguous tensor in resize
+      values.unsqueeze_(dim);
+    }
+    resize_output(values, result_sizes);
+  } else {
+    values = at::empty(result_sizes, self.options());
+  }
+  if (indices.defined()) {
+    TORCH_CHECK(
+        indices.dtype() == kLong, "output indices must be of scalar type Long");
+    TORCH_CHECK(
+        indices.device() == self.device(),
+        "output indices must be on same device as input");
+    if (!keepdim && indices.dim() == self.dim() - 1) {
+      // unsqueeze to preserve passed in noncontiguous tensor in resize
+      indices.unsqueeze_(dim);
+    }
+    resize_output(indices, result_sizes);
+  } else {
+    indices = at::empty(result_sizes, self.options().dtype(kLong));
+  }
+}
+
+// ensure we get good values and indices for topk
+//
+// The target sizes are self's sizes with the (wrapped) dim set to `k`; a
+// 0-dim self keeps empty sizes. For each of `values` / `indices`:
+//   - undefined: a fresh tensor of the target sizes is allocated (values
+//     with self's options, indices with dtype kLong);
+//   - defined: it is validated (values: same type as self; indices: dtype
+//     Long and same device as self) and resized in place with resize_.
+inline void _allocate_or_resize_output_with_indices(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t dim_,
+    int64_t k) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  auto result_sizes = self.sizes().vec();
+  if (!result_sizes.empty()) {
+    result_sizes[dim] = k;
+  }
+  if (values.defined()) {
+    TORCH_CHECK(
+        self.options().type_equal(values.options()),
+        "output values must be of same type as input");
+    values.resize_(result_sizes);
+  } else {
+    values = at::empty(result_sizes, self.options());
+  }
+  if (indices.defined()) {
+    TORCH_CHECK(
+        indices.dtype() == kLong, "output indices must be of scalar type Long");
+    TORCH_CHECK(
+        indices.device() == self.device(),
+        "output indices must be on same device as input");
+    indices.resize_(result_sizes);
+  } else {
+    indices = at::empty(result_sizes, self.options().dtype(kLong));
+  }
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SpectralOpsUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SpectralOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d7666dc6577133a8bfc7bd03f9b2190e6ffc91d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SpectralOpsUtils.h
@@ -0,0 +1,89 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+// Normalization types used in _fft_with_size
+enum class fft_norm_mode {
+  none,       // No normalization
+  by_root_n,  // Divide by sqrt(signal_size)
+  by_n,       // Divide by signal_size
+};
+
+// NOTE [ Fourier Transform Conjugate Symmetry ]
+//
+// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
+// assuming X is the transformed K-dimensional signal, we have
+//
+//     X[i_1, ..., i_K] = X[j_1, ..., j_K]*,
+//
+//       where j_k  = (N_k - i_k)  mod N_k, N_k being the signal size at dim k,
+//             * is the conjugate operator.
+//
+// Therefore, in such cases, FFT libraries return only roughly half of the
+// values to avoid redundancy:
+//
+//     X[:, :, ..., :floor(N / 2) + 1]
+//
+// This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such
+// halved signal will also be returned by default (flag onesided=True).
+// The following infer_ft_real_to_complex_onesided_size function calculates the
+// onesided size from the twosided size.
+//
+// Note that this loses some information about the size of the signal at the
+// last dimension. E.g., both 11 and 10 map to 6. Hence, the following
+// infer_ft_complex_to_real_onesided_size function takes in optional parameter
+// to infer the twosided size from given onesided size.
+//
+// cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional
+// MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE
+
+// Returns the one-sided size for a real signal of size `real_size`:
+// real_size / 2 (integer division) plus one.
+inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) {
+  const int64_t half_size = real_size / 2;
+  return half_size + 1;
+}
+
+// Recovers the two-sided (real) signal size from a one-sided complex size.
+// The two candidates are (complex_size - 1) * 2 and that value plus one.
+// With expected_size < 0 the larger (odd) candidate is returned; otherwise
+// expected_size must equal one of the candidates and is returned, else the
+// call fails through TORCH_CHECK.
+inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size,
+                                                      int64_t expected_size=-1) {
+  const int64_t even_size = (complex_size - 1) * 2;
+  const int64_t odd_size = even_size + 1;
+
+  // No expectation supplied: default to the odd candidate.
+  if (expected_size < 0) {
+    return odd_size;
+  }
+  if (expected_size == even_size || expected_size == odd_size) {
+    return expected_size;
+  }
+
+  std::ostringstream ss;
+  ss << "expected real signal size " << expected_size << " is incompatible "
+     << "with onesided complex frequency size " << complex_size;
+  TORCH_CHECK(false, ss.str());
+}
+
+using fft_fill_with_conjugate_symmetry_fn =
+    void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes,
+             IntArrayRef in_strides, const void* in_data,
+             IntArrayRef out_strides, void* out_data);
+DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub)
+
+// In real-to-complex transform, cuFFT and MKL only fill half of the values
+// due to conjugate symmetry. This function fills in the other half of the full
+// fft by using the Hermitian symmetry in the signal.
+// self should be the shape of the full signal and dims.back() should be the
+// one-sided dimension.
+// See NOTE [ Fourier Transform Conjugate Symmetry ]
+TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims);
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/StridedRandomAccessor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/StridedRandomAccessor.h
new file mode 100644
index 0000000000000000000000000000000000000000..911a18e1b6ed6a5739b24e4eb60052527b43fc4c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/StridedRandomAccessor.h
@@ -0,0 +1,306 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+namespace at::native {
+
+// (Const)StridedRandomAccessor is a
+// (const) random access iterator defined over
+// a strided array.
+
+// The traits below are to introduce __restrict__
+// modifier on different platforms.
+
+template <typename T>
+struct DefaultPtrTraits {
+  using PtrType = T*;  // plain pointer, no restrict qualifier
+};
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define RESTRICT __restrict
+#else
+#define RESTRICT __restrict__
+#endif
+
+template <typename T>
+struct RestrictPtrTraits {
+  using PtrType = T* RESTRICT;  // RESTRICT is __restrict on Windows, __restrict__ elsewhere
+};
+
+template <
+  typename T,
+  typename index_t = int64_t,
+  template <typename U> class PtrTraits = DefaultPtrTraits
+>
+class ConstStridedRandomAccessor {  // read-only random access iterator over a strided array
+public:
+  using difference_type = index_t;
+  using value_type = const T;
+  using pointer = const typename PtrTraits<T>::PtrType;
+  using reference = const value_type&;
+  using iterator_category = std::random_access_iterator_tag;
+
+  using PtrType = typename PtrTraits<T>::PtrType;
+  using index_type = index_t;
+
+  // Constructors {
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor(PtrType ptr, index_t stride)
+    : ptr{ptr}, stride{stride}
+  {}
+
+  C10_HOST_DEVICE
+  explicit ConstStridedRandomAccessor(PtrType ptr)
+    : ptr{ptr}, stride{static_cast<index_t>(1)}
+  {}
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor()
+    : ptr{nullptr}, stride{static_cast<index_t>(1)}
+  {}
+  // }
+
+  // Pointer-like operations {
+  C10_HOST_DEVICE
+  reference operator*() const {
+    return *ptr;
+  }
+
+  C10_HOST_DEVICE
+  const value_type* operator->() const {
+    return reinterpret_cast<const value_type*>(ptr);
+  }
+
+  C10_HOST_DEVICE
+  reference operator[](index_t idx) const {
+    return ptr[idx * stride];  // idx is scaled by stride before indexing
+  }
+  // }
+
+  // Prefix/postfix increment/decrement {
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator++() {
+    ptr += stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator++(int) {
+    ConstStridedRandomAccessor copy(*this);
+    ++*this;
+    return copy;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator--() {
+    ptr -= stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator--(int) {
+    ConstStridedRandomAccessor copy(*this);
+    --*this;
+    return copy;
+  }
+  // }
+
+  // Arithmetic operations {
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator+=(index_t offset) {
+    ptr += offset * stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator+(index_t offset) const {
+    return ConstStridedRandomAccessor(ptr + offset * stride, stride);
+  }
+
+  C10_HOST_DEVICE
+  friend ConstStridedRandomAccessor operator+(
+    index_t offset,
+    const ConstStridedRandomAccessor& accessor
+  ) {
+    return accessor + offset;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator-=(index_t offset) {
+    ptr -= offset * stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator-(index_t offset) const {
+    return ConstStridedRandomAccessor(ptr - offset * stride, stride);
+  }
+
+  // Note that this operator is well-defined when `this` and `other`
+  // represent the same sequences, i.e. when
+  // 1. this.stride == other.stride,
+  // 2. |other - this| / this.stride is an Integer.
+  C10_HOST_DEVICE
+  difference_type operator-(const ConstStridedRandomAccessor& other) const {
+    return (ptr - other.ptr) / stride;
+  }
+  // }
+
+  // Comparison operators {
+  C10_HOST_DEVICE
+  bool operator==(const ConstStridedRandomAccessor& other) const {
+    return (ptr == other.ptr) && (stride == other.stride);
+  }
+
+  C10_HOST_DEVICE
+  bool operator!=(const ConstStridedRandomAccessor& other) const {
+    return !(*this == other);
+  }
+
+  C10_HOST_DEVICE
+  bool operator<(const ConstStridedRandomAccessor& other) const {
+    return ptr < other.ptr;  // ordering compares ptr only; stride is not consulted
+  }
+
+  C10_HOST_DEVICE
+  bool operator<=(const ConstStridedRandomAccessor& other) const {
+    return (*this < other) || (*this == other);
+  }
+
+  C10_HOST_DEVICE
+  bool operator>(const ConstStridedRandomAccessor& other) const {
+    return !(*this <= other);
+  }
+
+  C10_HOST_DEVICE
+  bool operator>=(const ConstStridedRandomAccessor& other) const {
+    return !(*this < other);
+  }
+  // }
+
+protected:
+  PtrType ptr;
+  index_t stride;
+};
+
+template <
+  typename T,
+  typename index_t = int64_t,
+  template <typename U> class PtrTraits = DefaultPtrTraits
+>
+class StridedRandomAccessor  // mutable counterpart of ConstStridedRandomAccessor
+  : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
+public:
+  using difference_type = index_t;
+  using value_type = T;
+  using pointer = typename PtrTraits<T>::PtrType;
+  using reference = value_type&;
+
+  using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
+  using PtrType = typename PtrTraits<T>::PtrType;
+
+  // Constructors {
+  C10_HOST_DEVICE
+  StridedRandomAccessor(PtrType ptr, index_t stride)
+    : BaseType(ptr, stride)
+  {}
+
+  C10_HOST_DEVICE
+  explicit StridedRandomAccessor(PtrType ptr)
+    : BaseType(ptr)
+  {}
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor()
+    : BaseType()
+  {}
+  // }
+
+  // Pointer-like operations {
+  C10_HOST_DEVICE
+  reference operator*() const {
+    return *this->ptr;
+  }
+
+  C10_HOST_DEVICE
+  value_type* operator->() const {
+    return reinterpret_cast<value_type*>(this->ptr);
+  }
+
+  C10_HOST_DEVICE
+  reference operator[](index_t idx) const {
+    return this->ptr[idx * this->stride];  // idx is scaled by stride before indexing
+  }
+  // }
+
+  // Prefix/postfix increment/decrement {
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator++() {
+    this->ptr += this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator++(int) {
+    StridedRandomAccessor copy(*this);
+    ++*this;
+    return copy;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator--() {
+    this->ptr -= this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator--(int) {
+    StridedRandomAccessor copy(*this);
+    --*this;
+    return copy;
+  }
+  // }
+
+  // Arithmetic operations {
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator+=(index_t offset) {
+    this->ptr += offset * this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator+(index_t offset) const {
+    return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
+  }
+
+  C10_HOST_DEVICE
+  friend StridedRandomAccessor operator+(
+    index_t offset,
+    const StridedRandomAccessor& accessor
+  ) {
+    return accessor + offset;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator-=(index_t offset) {
+    this->ptr -= offset * this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator-(index_t offset) const {
+    return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
+  }
+
+  // Note that here we call BaseType::operator- version
+  C10_HOST_DEVICE
+  difference_type operator-(const BaseType& other) const {
+    return (static_cast<const BaseType&>(*this) - other);
+  }
+  // }
+};
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h
new file mode 100644
index 0000000000000000000000000000000000000000..a773709dfe66782f4c496dba20daf5a8cf1d1bb9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h
@@ -0,0 +1,107 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+// Indexing tensors by tensors
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+struct TensorIterator;
+}
+
+namespace at::native {
+
+using index_put_with_sort_fn = void (*)(
+    Tensor&,
+    const c10::List<std::optional<Tensor>>&,
+    const Tensor&,
+    bool accumulate,
+    bool unsafe);
+using index_put_with_sort_quantized_fn = void (*)(
+    Tensor& self,
+    const c10::List<std::optional<Tensor>>& indices,
+    const Tensor& value,
+    double scale,
+    int zero_point,
+    bool unsafe);
+using gather_fn = void (*)(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index);
+using scatter_fn = void (*)(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& src);
+using scatter_fill_fn = void (*)(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Scalar& src);
+using scatter_add_fn = void (*)(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& src);
+using scatter_reduce_fn = void (*)(
+    const Tensor& self,
+    const int64_t dim,
+    const Tensor& index,
+    const Tensor& src,
+    const ReductionType& reduce);
+using scatter_scalar_reduce_fn = void (*)(
+    const Tensor& self,
+    const int64_t dim,
+    const Tensor& index,
+    const Scalar& value,
+    const ReductionType& reduce);
+using scatter_reduce_two_fn = void (*)(
+    const Tensor& self,
+    const int64_t dim,
+    const Tensor& index,
+    const Tensor& src,
+    const ReductionType& reduce);
+
+DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub)
+DECLARE_DISPATCH(
+    index_put_with_sort_quantized_fn,
+    index_put_with_sort_quantized_stub)
+DECLARE_DISPATCH(gather_fn, gather_stub)
+DECLARE_DISPATCH(scatter_fn, scatter_stub)
+DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub)
+DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub)
+DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub)
+DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub)
+DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub)
+
+TORCH_API Tensor& index_out(
+    Tensor& result,
+    const Tensor& self,
+    const c10::List<std::optional<at::Tensor>>& indices);
+
+using scatter_add_expanded_index_fn =
+    void (*)(const Tensor&, const Tensor&, const Tensor&);
+using scatter_reduce_expanded_index_fn = void (*)(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    const ReductionType& reduce,
+    bool);
+using gather_expanded_index_fn =
+    void (*)(const Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub)
+DECLARE_DISPATCH(
+    scatter_reduce_expanded_index_fn,
+    scatter_reduce_expanded_index_stub)
+DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..014ba51a02e0906fccd57e8ad4f597bb2067388f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h
@@ -0,0 +1,110 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at::native {
+namespace {
+#ifndef STRIP_ERROR_MESSAGES
+// Joins the sizes of all defined tensors with ", "; undefined ones are skipped.
+inline std::string shapes_as_str(TensorList tensors) {
+  std::ostringstream joined;
+  const char* separator = "";
+  for (const auto& t : tensors) {
+    if (!t.defined()) {
+      continue;
+    }
+    // The separator stays empty until the first defined tensor has been written.
+    joined << separator << t.sizes();
+    separator = ", ";
+  }
+  return joined.str();
+}
+#endif
+} // anonymous namespace
+
+inline std::tuple<bool, Tensor> canDispatchToMaskedFill(
+    const Tensor& self,
+    const torch::List<std::optional<at::Tensor>>& indices,
+    const Tensor& value) {
+  if (!(value.numel() == 1 && value.device().is_cpu())) {  // only a one-element CPU value qualifies
+    return std::make_tuple(false, Tensor());
+  }
+  int64_t num_ind = 0;
+  Tensor mask;
+  auto self_device = self.device();
+  for (const std::optional<Tensor>& i : indices) {
+    if (!i.has_value() || !(*i).defined()) {
+      if (!mask.defined()) {  // count empty slots only while no mask has been seen
+        num_ind++;
+      }
+    } else {
+      const Tensor& index = *i;
+      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
+          index.device() != self_device || mask.defined()) {
+        return std::make_tuple(false, Tensor());
+      } else {
+        mask = index;
+        for (const auto j : c10::irange(index.dim())) {
+          int64_t srcIdx = num_ind + j;
+          TORCH_CHECK_INDEX(
+              index.size(j) == self.size(srcIdx),
+              "The shape of the mask ",
+              index.sizes(),
+              " at index ",
+              j,
+              " does not match the shape of the indexed tensor ",
+              self.sizes(),
+              " at index ",
+              srcIdx);
+        }
+        num_ind += mask.ndimension();
+      }
+    }
+  }
+  for ([[maybe_unused]] const auto i :
+       c10::irange(num_ind, self.ndimension())) {
+    mask = mask.unsqueeze(-1);  // append one trailing dim per remaining dim of self
+  }
+  return std::make_tuple(true, mask);
+}
+
+inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {  // builds the AdvancedIndex for indexing `self` with `orig`
+  checkIndexTensorTypes(orig, /*allow_int*/ true);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
+  // LongTensors
+  auto indices = expandTensors(self, orig, /*ensure_same_device=*/true);
+  // next broadcast all index tensors together
+  try {
+    indices = expand_outplace(indices);
+  } catch (std::exception&) {  // any std::exception from the broadcast is reported as an index error
+    TORCH_CHECK_INDEX(
+        false,
+        "shape mismatch: indexing tensors could not be broadcast together"
+        " with shapes ",
+        shapes_as_str(indices));
+  }
+  // add missing null Tensors so that indices.size() matches self.dim()
+  while (indices.size() < (size_t)self.dim()) {
+    indices.emplace_back();
+  }
+  // if the non-null indices are not all adjacent, transpose self and indices
+  // together so that they're adjacent at the front
+  if (!hasContiguousSubspace(indices)) {
+    std::tie(self, indices) = transposeToFront(self, indices);
+  }
+  for (auto& indice : indices) {  // convert every defined kInt index tensor to kLong
+    if (indice.defined() && indice.dtype() == at::kInt) {
+      indice = indice.to(at::kLong);
+    }
+  }
+
+  return AdvancedIndex(self, indices);
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorCompare.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorCompare.h
new file mode 100644
index 0000000000000000000000000000000000000000..76a5cde2a6373a56d39052f8bfc21ab131b1ba6d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorCompare.h
@@ -0,0 +1,61 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+class Tensor;
+struct TensorIterator;
+struct TensorIteratorBase;
+} // namespace at
+
+namespace at::native {
+
+using reduce_minmax_fn =
+    void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
+using structured_reduce_minmax_fn =
+    void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
+
+DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub)
+DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub)
+
+using where_fn = void (*)(TensorIterator&);
+DECLARE_DISPATCH(where_fn, where_kernel)
+
+using is_infinity_op_fn = void (*)(TensorIteratorBase&);
+DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub)
+DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub)
+
+using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
+DECLARE_DISPATCH(mode_fn, mode_stub)
+
+using clamp_tensor_fn = void (*)(TensorIteratorBase&);
+DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub)
+
+namespace detail {
+enum class ClampLimits { Min, Max, MinMax };
+}
+
+DECLARE_DISPATCH(
+    void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&),
+    clamp_scalar_stub)
+DECLARE_DISPATCH(
+    void (*)(TensorIteratorBase&, c10::Scalar),
+    clamp_min_scalar_stub)
+DECLARE_DISPATCH(
+    void (*)(TensorIteratorBase&, c10::Scalar),
+    clamp_max_scalar_stub)
+
+using isin_default_fn =
+    void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
+DECLARE_DISPATCH(isin_default_fn, isin_default_stub)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorConversions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorConversions.h
new file mode 100644
index 0000000000000000000000000000000000000000..cdd9cb0310602283ed1b0c2d2ef42c1094985f7d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorConversions.h
@@ -0,0 +1,36 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+class Tensor;
+namespace native {
+bool to_will_alias(
+    const Tensor& self,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    bool copy,
+    std::optional<c10::MemoryFormat> optional_memory_format);
+
+Tensor to_meta(const Tensor& tensor);
+std::optional<Tensor> to_meta(const std::optional<Tensor>& tensor);
+std::vector<Tensor> to_meta(at::ITensorListRef t_list);
+Tensor dense_to_sparse_with_mask(
+    const Tensor& self,
+    const Tensor& mask,
+    std::optional<c10::Layout> layout,
+    OptionalIntArrayRef blocksize,
+    std::optional<int64_t> dense_dim_opt);
+
+} // namespace native
+} // namespace at
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorDimApply.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorDimApply.h
new file mode 100644
index 0000000000000000000000000000000000000000..8d3fcbe72b2fd8c0c4fb55574d4acd7dc843d752
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorDimApply.h
@@ -0,0 +1,72 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+#include 
+
+namespace at::native {
+// input tensors are non-zero dim and non-empty
+template <typename T1, typename T2, typename Function>
+
+void tensor_dim_apply3(
+    const Tensor& self,
+    Tensor& values,
+    Tensor& indices,
+    int64_t dim,
+    Function func) {
+  int ndims = self.dim();
+  int tensor_dim_apply_has_finished = 0;
+  std::vector<int64_t> counter(ndims, 0);  // current position in every dim other than `dim`
+  const T1* self_data = self.const_data_ptr<T1>();
+  T1* values_data = values.data_ptr<T1>();
+  T2* indices_data = indices.data_ptr<T2>();
+  int64_t self_stride = self.stride(dim);
+  int64_t values_stride = values.stride(dim);
+  int64_t indices_stride = indices.stride(dim);
+  int self_dim_size = self.size(dim);
+
+  while (!tensor_dim_apply_has_finished) {
+    func(  // one call per 1-D slice along `dim`
+        self_data,
+        values_data,
+        indices_data,
+        self_dim_size,
+        self_stride,
+        values_stride,
+        indices_stride);
+    if (ndims == 1) {
+      break;
+    }
+    for (const auto dim_i : c10::irange(ndims)) {  // advance the counter, skipping `dim`
+      if (dim_i == dim) {
+        if (dim_i == (ndims - 1)) {
+          tensor_dim_apply_has_finished = 1;
+          break;
+        }
+        continue;
+      }
+      counter[dim_i]++;
+      self_data += self.stride(dim_i);
+      values_data += values.stride(dim_i);
+      indices_data += indices.stride(dim_i);
+
+      if (counter[dim_i] == self.size(dim_i)) {
+        if (dim_i == ndims - 1) {
+          tensor_dim_apply_has_finished = 1;
+          break;
+        } else {
+          self_data -= counter[dim_i] * self.stride(dim_i);  // rewind this dim, carry into the next
+          values_data -= counter[dim_i] * values.stride(dim_i);
+          indices_data -= counter[dim_i] * indices.stride(dim_i);
+          counter[dim_i] = 0;
+        }
+      } else {
+        break;
+      }
+    }
+  }
+}
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5ec9b0f2b99536dde9ea31648095ab9fb022122
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h
@@ -0,0 +1,57 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// This file includes utilities for dynamic_casting done by TensorIterator, see
+// CUDALoops.cuh and Loops.h.
+
+// dynamic_casting handles when the types expected by the iterator do not match
+// the types of the arguments to the function that is being called. On CUDA, the
+// cast is currently pushed down into the kernel (for performance reasons). On
+// CPU, there is currently an internal assert that a dynamic_cast is not needed.
+
+namespace at::native {
+
+// `needs_dynamic_casting` compares the types expected by iterator
+// (i.e. dtypes of the operands) with the actual type of the arguments
+// (and returns) of func_t
+template <typename func_t, int nargs = function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+  static bool check(TensorIteratorBase& iter) {
+    using traits = function_traits<func_t>;
+    using cpp_type = typename traits::template arg<nargs - 1>::type;
+    using cpp_map = c10::CppTypeToScalarType<cpp_type>;
+
+    if (iter.input_dtype(nargs - 1) != cpp_map::value) {
+      return true;
+    }
+    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);  // recurse over the remaining arguments
+  }
+};
+
+template <typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+  static bool check(TensorIteratorBase& iter) {
+    using traits = function_traits<func_t>;
+    using cpp_type = typename traits::result_type;
+
+    // we could assert output numbers are correct here, but checks
+    // (including arity) are currently pushed outside of this struct.
+    if constexpr (std::is_void_v<cpp_type>) {
+      return false;
+    } else {
+      return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
+    }
+  }
+};
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorProperties.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorProperties.h
new file mode 100644
index 0000000000000000000000000000000000000000..53cb90246892b6b0e45fef516c4c70516cf14e3c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorProperties.h
@@ -0,0 +1,17 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+// See NOTE: [Tensor vs. TensorBase]
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+TORCH_API bool cudnn_is_acceptable(const TensorBase& self);
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TransposeType.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TransposeType.h
new file mode 100644
index 0000000000000000000000000000000000000000..df8db2787d469eb272ed3d1dfd44787e42ae3c5c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TransposeType.h
@@ -0,0 +1,32 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+#include 
+
+C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
+
+namespace at::native {
+
+// Used as an interface between the different BLAS-like libraries
+enum class TransposeType {
+  NoTranspose,    // to_blas maps this to 'N'
+  Transpose,      // to_blas maps this to 'T'
+  ConjTranspose,  // to_blas maps this to 'C'
+};
+
+// Transforms TransposeType into the BLAS / LAPACK format
+static inline char to_blas(TransposeType trans) {
+  switch (trans) {  // no `default:` label; -Wswitch-default is ignored around this namespace
+    case TransposeType::Transpose: return 'T';
+    case TransposeType::NoTranspose: return 'N';
+    case TransposeType::ConjTranspose: return 'C';
+  }
+  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");  // reached only when no case matched
+}
+
+}  // namespace at::native
+
+C10_DIAGNOSTIC_POP()
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TriangularOpsUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TriangularOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..86a58a2e01abe58216af6f3b92f877a31bb5c3c5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TriangularOpsUtils.h
@@ -0,0 +1,62 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#include 
+#include 
+
+namespace at::native {
+
+/*
+ * Given batches of matrices with arbitrary batch dim,
+ * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
+ */
+static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
+  const int64_t num_batch_dims = batched_matrices.ndimension() - 2;
+  int64_t count = 1;
+  for (int64_t d = 0; d < num_batch_dims; ++d) {
+    // A batch dim with stride 0 is ignored and leaves the count unchanged.
+    count *= (batched_matrices.stride(d) == 0) ? 1 : batched_matrices.size(d);
+  }
+  return count;
+}
+
+/* Checks a necessary property for the triu and tril implementations, hence the name.
+ * Returns (true, tensor) when the batch strides describe densely packed matrices, else
+ * (false, restrided or contiguous tensor). With allow_zero_stride, tensors of at most 3 dims pass.
+ */
+static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
+  // Complete contiguity is the most desired property: a contiguous tensor returns early,
+  // with true only if its strides equal the default batched-matrix strides
+  if (tensor.is_contiguous()) {
+    auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
+    if (tensor.strides() == default_strides_for_size) {
+      return std::make_tuple(true, tensor);
+    } else {
+      return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
+    }
+  }
+
+  int64_t dims = tensor.dim();
+
+  // With allow_zero_stride, tensors with dimension less than 4 are accepted as is
+  if (allow_zero_stride && dims <= 3) {
+    return std::make_tuple(true, tensor);
+  }
+
+  int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
+  for (int64_t i = dims - 3; i >= 0; i--) {
+    // Skip a trivial (stride-0 or size-1) outermost dimension
+    if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
+      continue;
+    }
+    if (expected_stride != tensor.stride(i)) {
+      return std::make_tuple(false, tensor.contiguous());
+    }
+    expected_stride *= tensor.size(i);
+  }
+  return std::make_tuple(true, tensor);
+}
+
+}  // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Unfold2d.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Unfold2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..e686c33e0a2541eed8c308a4dc77bc160d1ccc2f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Unfold2d.h
@@ -0,0 +1,53 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+using unfold2d_copy_fn = void (*)(
+    ScalarType dtype,
+    void *finput,
+    const void *input,
+    int64_t kH,
+    int64_t kW,
+    int64_t dH,
+    int64_t dW,
+    int64_t padH,
+    int64_t padW,
+    int64_t n_input_plane,
+    int64_t input_height,
+    int64_t input_width,
+    int64_t output_height,
+    int64_t output_width,
+    bool is_channels_last
+);
+
+using unfold2d_acc_fn = void (*)(
+    ScalarType dtype,
+    void *finput,
+    void *input,
+    int64_t kH,
+    int64_t kW,
+    int64_t dH,
+    int64_t dW,
+    int64_t padH,
+    int64_t padW,
+    int64_t n_input_plane,
+    int64_t input_height,
+    int64_t input_width,
+    int64_t output_height,
+    int64_t output_width,
+    bool is_channels_last
+);
+
+DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub)
+DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub)
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UnfoldBackward.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UnfoldBackward.h
new file mode 100644
index 0000000000000000000000000000000000000000..460ca209afb397ff15af4cde24cc6665f6d380f9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UnfoldBackward.h
@@ -0,0 +1,115 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+
+using unfold_backward_fn = void (*)(  // Function-pointer type for the unfold_backward_stub dispatch entry declared below.
+  Tensor& grad_in,  // mutable tensor (non-const reference)
+  const Tensor& grad,  // read-only tensor
+  int64_t dim,  // presumably the dimension that was unfolded — confirm against the kernel
+  int64_t size,  // presumably the size of each fold
+  int64_t step  // presumably the step between consecutive folds
+);
+
+DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub)  // declares dispatch stub unfold_backward_stub for kernels of type unfold_backward_fn
+
+namespace {
+
+// Note on naming: it is unconventional.
+// grad_in does not mean that it is a gradient w.r.t. the input;
+// grad_in/grad_out are just the input/output of the unfold_backward kernel.
+
+[[maybe_unused]] TensorIterator _make_unfold_backward_iter_over_grad_out(  // Builds a TensorIterator with output grad_out (restrided) and const inputs grad_in (restrided) and an index tensor along dim.
+    Tensor& grad_out,
+    const Tensor& grad_in,
+    int64_t dim,
+    int64_t size,
+    int64_t step) {
+  dim = maybe_wrap_dim(dim, grad_out.dim());  // normalize dim relative to grad_out's rank
+  // the last dim of grad_in stores the folds (it is dropped below via pop_back / squeeze(-1))
+
+  auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
+  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
+  // dictates the number of elements to iterate over
+  // in dimension `dim`
+  auto iter_dim_size = std::min(
+    grad_out_dim_size,
+    (grad_in_dim_size - 1) * step + size  // presumably the extent covered by all folds: last fold starts at (n-1)*step and spans `size`
+  );
+
+  /* prepare grad_out for TensorIterator { */
+  auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
+  auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
+  grad_out_sizes[dim] = iter_dim_size;  // restrict iteration along dim to iter_dim_size elements
+  auto grad_out_restrided = grad_out.as_strided(
+    grad_out_sizes, grad_out_strides
+  );
+  /* } */
+
+  /* prepare grad_in for TensorIterator { */
+  auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
+  auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());
+
+  // set strides for dim to 0
+  // and size to 1 because
+  // this dimension is indexed inside the kernel
+  grad_in_strides[dim] = 0;
+  grad_in_sizes[dim] = 1;
+
+  grad_in_strides.pop_back();  // drop the trailing dimension's stride
+  grad_in_sizes.pop_back();  // drop the trailing dimension's size
+
+  auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
+    grad_in_sizes, grad_in_strides
+  );
+  /* } */
+
+  // During the TensorIterator iteration we have to know
+  // i_dim in grad_out[i_1,...,i_dim,...i_n],
+  // idx_dim stores this information
+  /* prepare idx_dim for TensorIterator { */
+  auto idx_dim = at::arange(
+    0, iter_dim_size, grad_in.options().dtype(at::kLong)  // int64 indices 0..iter_dim_size-1, created with grad_in's other tensor options
+  );
+
+  auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());
+
+  auto idx_dim_strides = std::vector(grad_out_dim, 0);  // all strides 0 ...
+  auto idx_dim_sizes = std::vector(grad_out_dim, 1);  // ... and all sizes 1, except along dim (set below)
+
+  idx_dim_strides[dim] = 1;
+  idx_dim_sizes[dim] = iter_dim_size;
+
+  // idx_dim has size 1 in every dimension except dim, so TensorIterator broadcasts it against grad_out's sizes
+  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
+  /* } */
+
+  auto iter = TensorIteratorConfig()
+    .set_check_mem_overlap(false)  // NOTE(review): presumably required because of the as_strided views above — confirm
+    .check_all_same_dtype(false)  // idx_dim is int64 while the grad tensors keep their own dtype
+    .resize_outputs(false)  // the output view was already sized above
+    .add_owned_output(grad_out_restrided)
+    .add_owned_const_input(grad_in_restrided)
+    .add_owned_const_input(idx_dim_restrided)
+    .build();
+
+  return iter;
+}
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/batch_norm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/batch_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a79be420f02438215df64272a8c5aa46c222e80
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/batch_norm.h
@@ -0,0 +1,43 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,  // one mutable tensor, then seven const tensors, a bool and a double (parameters are unnamed)
+    const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);  // NOTE(review): bool/double are presumably the training flag and eps — confirm against the kernel
+using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);  // two mutable tensors and one const tensor; presumably two statistics outputs and the input — confirm
+using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,  // three mutable tensors (presumably the gradient outputs), then seven const tensors, a bool and a double
+        const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
+
+DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub)  // dispatch stub for kernels of type batch_norm_fn
+DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub)  // dispatch stub for kernels of type batch_norm_collect_stats_fn
+DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub)  // dispatch stub for kernels of type batch_norm_backward_fn
+
+// Returns t's accessor when t is defined; otherwise a TensorAccessor built from null pointers, to work around undefined tensors.
+template 
+static TensorAccessor conditional_accessor_1d(const Tensor& t) {
+  if (! t.defined()) {
+    return TensorAccessor(nullptr, nullptr, nullptr);  // placeholder accessor; all three constructor arguments are null
+  }
+  return t.accessor();
+}
+
+template 
+static scalar_t* conditional_data_ptr(const Tensor& t) {  // Returns a data pointer for t, or nullptr when t is undefined.
+  if constexpr (std::is_const_v) {  // compile-time branch: const element type uses const_data_ptr(), otherwise data_ptr()
+    return t.defined() ? t.contiguous().const_data_ptr()  // NOTE(review): contiguous() returns a temporary; if t is not already contiguous the pointer may outlive it — confirm callers pass contiguous tensors
+                      : nullptr;
+  } else {
+    return t.defined() ? t.contiguous().data_ptr()  // NOTE(review): same temporary-lifetime caveat as the const branch
+                      : nullptr;
+  }
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/group_norm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/group_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..19ea16e8dfff68d95b7e8f4644b11f985a217134
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/group_norm.h
@@ -0,0 +1,47 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using forward_fn = void (*)(  // Function-pointer type for the GroupNormKernel dispatch entry declared in this header.
+    const Tensor& /* X */,  // const references are read-only inputs
+    const Tensor& /* gamma */,
+    const Tensor& /* beta */,
+    int64_t /* N */,
+    int64_t /* C */,
+    int64_t /* HxW */,
+    int64_t /* group */,
+    double /* eps */,
+    Tensor& /* Y */,  // non-const references are writable outputs
+    Tensor& /* mean */,
+    Tensor& /* rstd */);
+
+using backward_fn = void (*)(  // Function-pointer type for the GroupNormBackwardKernel dispatch entry declared in this header.
+    const Tensor& /* dY */,  // const references are read-only inputs
+    const Tensor& /* X */,
+    const Tensor& /* mean */,
+    const Tensor& /* rstd */,
+    const Tensor& /* gamma */,
+    int64_t /* N */,
+    int64_t /* C */,
+    int64_t /* HxW */,
+    int64_t /* group */,
+    Tensor& /* dX */,  // non-const references are writable outputs
+    Tensor& /* dgamma */,
+    Tensor& /* dbeta */);
+
+DECLARE_DISPATCH(forward_fn, GroupNormKernel)  // dispatch stub for kernels of type forward_fn
+DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel)  // dispatch stub for kernels of type backward_fn
+
+} // namespace native
+} // namespace at
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/im2col.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/im2col.h
new file mode 100644
index 0000000000000000000000000000000000000000..7ca8091f4a0eac60acba84b08ac0b11e68402f75
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/im2col.h
@@ -0,0 +1,154 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at::native {
+
+template 
+static void im2col(  // Copies each kernel window of data_im into data_col, writing zeros where the window falls outside the image.
+    const T* data_im,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t output_height,
+    const int64_t output_width,
+    const int64_t kernel_h,
+    const int64_t kernel_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    T* data_col,
+    bool is_channels_last = false) {
+  const int64_t height_col = output_height;
+  const int64_t width_col = output_width;
+  const int64_t channels_col = channels * kernel_h * kernel_w;  // one column row per (channel, kernel row, kernel column)
+
+  if (is_channels_last) {  // data_im is indexed as (h * width + w) * channels, i.e. channels are innermost
+    at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) {  // parallel over output positions
+      int64_t h_col{0}, w_col{0};
+      data_index_init(begin, h_col, height_col, w_col, width_col);  // decompose the flat start index into (h_col, w_col)
+
+      for (const auto i_col : c10::irange(begin, end)) {
+        for (const auto h_offset : c10::irange(kernel_h)) {
+          int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;  // input row; negative or >= height means padding
+          for (const auto w_offset : c10::irange(kernel_w)) {
+            int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;  // input column; may also fall in padding
+
+            const T* slice_im = data_im + (h_im * width + w_im) * channels;  // computed before the bounds check; only read when in bounds
+            T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels;
+
+            if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+              std::copy_n(slice_im, channels, slice_col);  // copy all channels of this pixel at once
+            } else {
+              std::fill_n(slice_col, channels, T(0));  // padding region: write zeros
+            }
+          }
+        }
+
+        // move to the next index
+        data_index_step(h_col, height_col, w_col, width_col);
+      }
+    });
+  } else {  // data_im is indexed as (c * height + h) * width + w, i.e. channels are outermost
+    at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) {  // parallel over rows of the column buffer
+      int64_t c_im{0}, h_offset{0}, w_offset{0};
+      data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w);  // decompose the flat row index into (c_im, h_offset, w_offset)
+
+      for (const auto c_col : c10::irange(begin, end)) {
+        for (const auto h_col : c10::irange(height_col)) {
+          int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+          for (const auto w_col : c10::irange(width_col)) {
+            int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
+            data_col[(c_col * height_col + h_col) * width_col + w_col] =
+                (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
+                ? c10::load(&(data_im[(c_im * height + h_im) * width + w_im]))  // in bounds: read the input element
+                : static_cast(0);  // padding region: write zero
+          }
+        }
+
+        // move to the next index
+        data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w);
+      }
+    });
+  }
+}
+
+template 
+static void col2im(  // Zero-fills data_im, then adds every in-bounds data_col entry into the image element its window position maps to.
+    const T* data_col,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t output_height,
+    const int64_t output_width,
+    const int64_t kernel_h,
+    const int64_t kernel_w,
+    const int64_t pad_h,
+    const int64_t pad_w,
+    const int64_t stride_h,
+    const int64_t stride_w,
+    const int64_t dilation_h,
+    const int64_t dilation_w,
+    T* data_im,
+    bool is_channels_last = false) {
+  std::fill_n(data_im, height * width * channels, T(0));  // start from zero because the loops below accumulate with +
+
+  const int64_t height_col = output_height;
+  const int64_t width_col = output_width;
+  const int64_t channels_col = channels * kernel_h * kernel_w;  // one column row per (channel, kernel row, kernel column)
+
+  if (is_channels_last) {  // data_im is indexed as (h * width + w) * channels; plain serial loops, no parallel_for
+    for (const auto h_col : c10::irange(height_col)) {
+      for (const auto w_col : c10::irange(width_col)) {
+        for (const auto h_offset : c10::irange(kernel_h)) {
+          int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;  // image row; out-of-range values are skipped below
+          for (const auto w_offset : c10::irange(kernel_w)) {
+            int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;  // image column; out-of-range values are skipped below
+
+            T* slice_im = data_im + (h_im * width + w_im) * channels;  // computed before the bounds check; only used when in bounds
+            const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w
+                + h_offset * kernel_w + w_offset) * channels;
+
+            if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) {
+              std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus());  // slice_im[c] += slice_col[c] for every channel
+            }
+          }
+        }
+      }
+    }
+  } else {  // data_im is indexed as (c * height + h) * width + w
+    for (const auto c_col : c10::irange(channels_col)) {
+      int64_t w_offset = c_col % kernel_w;  // decompose the row index: kernel column varies fastest,
+      int64_t h_offset = (c_col / kernel_w) % kernel_h;  // then kernel row,
+      int64_t c_im = c_col / kernel_h / kernel_w;  // then image channel
+
+      for (const auto h_col : c10::irange(height_col)) {
+        int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+        for (const auto w_col : c10::irange(width_col)) {
+          int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
+
+          if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)  // entries that map into padding are dropped
+            data_im[(c_im * height + h_im) * width + w_im] +=
+                data_col[(c_col * height_col + h_col) * width_col + w_col];
+        }
+      }
+    }
+  }
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/layer_norm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/layer_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..a667349ad6fe9933d23a76b34345affeb650bdea
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/layer_norm.h
@@ -0,0 +1,157 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+namespace at::native {
+
+namespace {
+
+C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(  // Validates normalized_shape, weight and input shapes; reports failures through the TORCH_* check macros.
+    const Tensor& input,
+    c10::SymIntArrayRef normalized_shape,
+    const Tensor& weight /* optional */) {
+
+  const int normalized_ndim = normalized_shape.size();  // size() is narrowed to int here
+  TORCH_CHECK(
+      normalized_ndim >= 1,
+      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
+      "containing at least one element, but got normalized_shape = ",
+      normalized_shape);
+  if (weight.defined()) {  // an undefined weight is accepted without a shape check
+    TORCH_SYM_CHECK(
+        sym_equals(weight.sym_sizes(), normalized_shape),
+        "Expected weight to be of same shape as normalized_shape, but got ",
+        "weight of shape ",
+        weight.sym_sizes(),
+        " and normalized_shape = ",
+        normalized_shape);
+  }
+
+  const auto input_ndim = input.dim();
+  const auto input_shape = input.sym_sizes();
+  TORCH_CHECK_VALUE(  // checked first so that the slice() offset used below is not negative
+      input_ndim >= normalized_ndim,
+      "Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);
+
+  auto expect_input_shape_msg = c10::str(  // the message is built unconditionally, before the check that uses it
+      "Given normalized_shape=", normalized_shape,
+      ", expected input with shape [*", c10::Join(", ", normalized_shape),  // sizes follow "[*" directly, joined by ", "
+      "], but got input of size", input_shape);
+
+  TORCH_SYM_CHECK(
+      sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),  // trailing input dims must equal normalized_shape
+      expect_input_shape_msg);
+}
+
+C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs(  // Validates the shapes and returns (M, N): the products of the leading and of the normalized dimensions of input.
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */) {
+
+  const int normalized_ndim = normalized_shape.size();  // size() is narrowed to int here
+  TORCH_CHECK(
+      normalized_ndim >= 1,
+      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
+      "containing at least one element, but got normalized_shape = ",
+      normalized_shape);
+  TORCH_CHECK(
+      !weight.defined() || weight.sizes().equals(normalized_shape),  // an undefined weight passes
+      "Expected weight to be of same shape as normalized_shape, but got ",
+      "weight of shape ",
+      weight.sizes(),
+      " and normalized_shape = ",
+      normalized_shape);
+  TORCH_CHECK(
+      !bias.defined() || bias.sizes().equals(normalized_shape),  // an undefined bias passes
+      "Expected bias to be of same shape as normalized_shape, but got ",
+      "bias of shape ",
+      bias.sizes(),
+      " and normalized_shape = ",
+      normalized_shape);
+
+  const auto input_shape = input.sizes();
+  const auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||  // short-circuit: slice() below only runs when input has enough dimensions
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;  // every size is prefixed with ", ", so the message reads "[*, d0, d1, ...]"
+    }
+    ss << "], but got input of size" << input_shape;
+    TORCH_CHECK(false, ss.str());  // unconditional failure carrying the message built above
+  }
+
+  const int axis = input_ndim - normalized_ndim;  // index of the first normalized dimension
+  const int64_t M =
+      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);  // product of the dimensions before axis
+  const int64_t N =
+      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());  // product of the dimensions from axis onward
+
+  return std::make_pair(M, N);
+}
+
+} // namespace
+
+void layer_norm_cpu_out(  // Declaration only; the definition is not part of this header.
+    at::Tensor& out,  // the only mutable parameter
+    const at::Tensor& input,
+    const Tensor& gamma,
+    const Tensor& beta,
+    double eps,
+    int64_t M,
+    int64_t N);  // NOTE(review): M/N presumably match the pair returned by _check_layer_norm_inputs — confirm
+
+std::tuple rms_norm_composite(  // Declaration only; returns a tuple and takes an IntArrayRef normalized_shape.
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const std::optional& weight_opt /* optional */,
+    std::optional eps);  // eps may be omitted by the caller (std::optional)
+
+Tensor rms_norm_symint(  // Declaration only; returns a Tensor and takes a SymIntArrayRef normalized_shape.
+    const Tensor& input,
+    c10::SymIntArrayRef normalized_shape,
+    const std::optional& weight_opt /* optional */,
+    std::optional eps);  // eps may be omitted by the caller (std::optional)
+
+using forward_fn = void (*)(  // Function-pointer type for the LayerNormKernel dispatch entry declared in this header.
+    const Tensor& /* X */,  // const references are read-only inputs
+    const Tensor& /* gamma */,
+    const Tensor& /* beta */,
+    int64_t /* M */,
+    int64_t /* N */,
+    double /* eps */,
+    Tensor* /* Y */,  // outputs are passed as pointers; NOTE(review): whether they may be null is not visible here — confirm
+    Tensor* /* mean */,
+    Tensor* /* rstd */);
+
+using backward_fn = void (*)(  // Function-pointer type for the LayerNormBackwardKernel dispatch entry declared in this header.
+    const Tensor& /* dY */,  // const references are read-only inputs
+    const Tensor& /* X */,
+    const Tensor& /* mean */,
+    const Tensor& /* rstd */,
+    const Tensor& /* gamma */,
+    int64_t /* M */,
+    int64_t /* N */,
+    Tensor* /* dX */,  // outputs are passed as pointers; NOTE(review): whether they may be null is not visible here — confirm
+    Tensor* /* dgamma */,
+    Tensor* /* dbeta */);
+
+DECLARE_DISPATCH(forward_fn, LayerNormKernel)  // dispatch stub for kernels of type forward_fn
+DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel)  // dispatch stub for kernels of type backward_fn
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/vol2col.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/vol2col.h
new file mode 100644
index 0000000000000000000000000000000000000000..454e468ab35edf7fed2d0b28d898a8add76b8164
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/vol2col.h
@@ -0,0 +1,114 @@
+#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
+#pragma once
+
+#include 
+
+namespace at::native {
+
+template 
+void vol2col(  // Copies each 3-D kernel window of data_vol into data_col, writing 0 where the window falls outside the volume.
+    const T* data_vol,
+    const int64_t channels,
+    const int64_t depth,
+    const int64_t height,
+    const int64_t width,
+    const int64_t depth_col,
+    const int64_t height_col,
+    const int64_t width_col,
+    const int64_t kT,
+    const int64_t kernel_height,
+    const int64_t kernel_width,
+    const int64_t pT,
+    const int64_t pH,
+    const int64_t pW,
+    const int64_t dT,
+    const int64_t dH,
+    const int64_t dW,
+    const int64_t dilationT,
+    const int64_t dilationH,
+    const int64_t dilationW,
+    T* data_col) {
+  int64_t c, t, h, w;
+  int64_t channels_col = channels * kT * kernel_height * kernel_width;  // one column row per (channel, kernel t, kernel h, kernel w)
+  for (c = 0; c < channels_col; ++c) {
+    int64_t w_offset = c % kernel_width;  // decompose the row index: kernel width varies fastest,
+    int64_t h_offset = (c / kernel_width) % kernel_height;  // then kernel height,
+    int64_t t_offset = (c / kernel_width / kernel_height) % kT;  // then kernel depth,
+    int64_t c_vol = c / kT / kernel_height / kernel_width;  // then the volume channel
+    for (t = 0; t < depth_col; ++t) {
+      int64_t t_pad = t * dT - pT + t_offset * dilationT;  // input depth coordinate: dT scales the output index, pT is subtracted
+      for (h = 0; h < height_col; ++h) {
+        int64_t h_pad = h * dH - pH + h_offset * dilationH;  // input row coordinate
+        for (w = 0; w < width_col; ++w) {
+          int64_t w_pad = w * dW - pW + w_offset * dilationW;  // input column coordinate
+          if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
+              w_pad >= 0 && w_pad < width)
+            data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
+                data_vol
+                    [((c_vol * depth + t_pad) * height + h_pad) * width +
+                     w_pad];  // in bounds: copy the input element
+          else
+            data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
+                0;  // outside the volume: write zero
+        }
+      }
+    }
+  }
+}
+
+template 
+void col2vol(  // Zero-fills data_vol, then adds every in-bounds data_col entry into the volume element its window position maps to.
+    const T* data_col,
+    const int64_t channels,
+    const int64_t depth,
+    const int64_t height,
+    const int64_t width,
+    const int64_t out_depth,
+    const int64_t out_height,
+    const int64_t out_width,
+    const int64_t kT,
+    const int64_t kernel_height,
+    const int64_t kernel_width,
+    const int64_t pT,
+    const int64_t pH,
+    const int64_t pW,
+    const int64_t dT,
+    const int64_t dH,
+    const int64_t dW,
+    const int64_t dilationT,
+    const int64_t dilationH,
+    const int64_t dilationW,
+    T* data_vol) {
+  memset(data_vol, 0, sizeof(T) * depth * height * width * channels);  // byte-wise zeroing; NOTE(review): assumes all-zero bytes represent T's zero — confirm for every T used
+  int64_t depth_col = out_depth;
+  int64_t height_col = out_height;
+  int64_t width_col = out_width;
+  int64_t channels_col = channels * kT * kernel_height * kernel_width;  // one column row per (channel, kernel t, kernel h, kernel w)
+  for (int64_t c = 0; c < channels_col; ++c) {
+    int64_t w_offset = c % kernel_width;  // decompose the row index: kernel width varies fastest,
+    int64_t h_offset = (c / kernel_width) % kernel_height;  // then kernel height,
+    int64_t t_offset = (c / kernel_width / kernel_height) % kT;  // then kernel depth,
+    int64_t c_vol = c / kT / kernel_height / kernel_width;  // then the volume channel
+    for (int64_t t = 0; t < depth_col; ++t) {
+      int64_t t_pad = t * dT - pT + t_offset * dilationT;  // volume depth coordinate
+      for (int64_t h = 0; h < height_col; ++h) {
+        int64_t h_pad = h * dH - pH + h_offset * dilationH;  // volume row coordinate
+        for (int64_t w = 0; w < width_col; ++w) {
+          int64_t w_pad = w * dW - pW + w_offset * dilationW;  // volume column coordinate
+          if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
+              w_pad >= 0 && w_pad < width)  // entries that map outside the volume are dropped
+            data_vol
+                [((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] +=
+                data_col
+                    [((c * depth_col + t) * height_col + h) * width_col + w];  // accumulate: several column entries can map to the same element
+        }
+      }
+    }
+  }
+}
+
+} // namespace at::native
+
+#else
+#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
+#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5294463aa5c3c325456094e5c92cf5cfb159ed96
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa52033c7d2098b3e88c57499388205f7ccc3865
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81217065e2e870ab39fed24973a5f8cc6b712ad7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3aa930955199c7f83f0d1f4146922d151c73486e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e8faf485c24cedcfdf1480fbd3d4fc99410264e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff1aa5ff8fd16bf999441c6bbca38f457b7742b2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d24a79721e6e8d796ad6007caa31c5bd6a48b447
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decompositions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decompositions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f23d0645cae52da5b9f5b6c6b02bd93f1f4b334c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decompositions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81ec9c4ee13bd3f22a1c1c443a80307aa182723e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..156bdb1260d47eab44618a2891e341342b5a8413
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04affbda275383bc6cac4206a6fdbce0b792e835
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_logging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_logging.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4a7329beb7773108eb626cdf29477f99bdfc335
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_logging.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99e58418c000939f33111cdcabcdab70c1d61158
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_pickle.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_pickle.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85ccc4dedc39f649e2ca779a58fd996fb814c035
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_pickle.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7f803dc7df67dd0a2fa3b7084ad7f9fce38eddd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8e938602cd041c59a6cce880c22bb28d7a05c46
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4eaf72fc2af38f91f58d1762787205459712c5e0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_shape_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_shape_functions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51080405eb701f872ba5eab5ae2870d7aaaa380b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_shape_functions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84d8722103d4ca55cf77b07a266e1f5dd30ae13f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc623fe3b745102a7c38dfaee02ff3f5d1861c9a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c79635e28a32e18f26f5f9f7f5af3576465fe8ce
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bba65a325998f0e3e8f42dfca2c30a4adfde6c7c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c6b7a99459f08708c2e8ee96d30649b02b963c6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/quantized.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/quantized.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a76f83a80d096d89efaa185b116d33d6f4b1d2cf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/quantized.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/supported_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/supported_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..758dad2145c0a1cb982082c5a1845169c8a16145
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/supported_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b52b1f97eb66767588f2012147e60734cdccdc31
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/_passes/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/_passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c410b8fbb7fd329442aa867c0e39c03cd4f15199
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py
@@ -0,0 +1,46 @@
+"""
+Tools to help with tensor property propagation.
+
+This is not intended to be imported directly; please use the exposed
+functionalities in `torch.jit`.
+"""
+
+from typing import Any
+
+import torch
+from torch import TensorType
+from torch._C import Graph
+
+
+def apply_input_props_using_example(graph: Graph, example_input: list[Any]) -> None:
+    """
+    Applies properties for each tensor in the graph inputs
+    using the example supplied.
+    """
+    graph_inputs = list(graph.inputs())
+    if len(graph_inputs) == 0:
+        return
+
+    # Strip self args off for methods
+    in_0 = graph_inputs[0]
+    if isinstance(in_0.type(), torch._C.ClassType) and in_0.debugName() == "self":
+        graph_inputs = graph_inputs[1:]
+
+    if not len(graph_inputs) == len(example_input):
+        raise RuntimeError(
+            "Number of inputs in graph does not match number of inputs in the example"
+        )
+
+    for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)):
+        if example_i is None:
+            continue  # Skip the type check
+
+        if isinstance(example_i, torch.Tensor) != isinstance(
+            graph_i.type(), TensorType
+        ):
+            raise RuntimeError(
+                f"Input {i} does not match type of example", graph_i, example_i
+            )
+
+        if isinstance(example_i, torch.Tensor):
+            graph_i.setType(TensorType.create_from_tensor(example_i))  # type: ignore[arg-type]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/mobile/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/mobile/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..608d1c2f7798d84498907c032a2a4acc6f65f7ef
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/jit/mobile/__init__.py
@@ -0,0 +1,244 @@
+# mypy: allow-untyped-defs
+import os
+
+import torch
+from torch.jit._serialization import validate_map_location
+
+
+def _load_for_lite_interpreter(f, map_location=None):
+    r"""
+    Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.
+
+    Args:
+        f: a file-like object (has to implement read, readline, tell, and seek),
+            or a string / os.PathLike containing a file name
+        map_location: a string or torch.device used to dynamically remap
+            storages to an alternative set of devices.
+
+    Returns:
+        A :class:`LiteScriptModule` object.
+
+    Example:
+
+    .. testcode::
+
+        import torch
+        import io
+
+        # Load LiteScriptModule from saved file path
+        torch.jit.mobile._load_for_lite_interpreter('lite_script_module.pt')
+
+        # Load LiteScriptModule from io.BytesIO object
+        with open('lite_script_module.pt', 'rb') as f:
+            buffer = io.BytesIO(f.read())
+
+        # Load all tensors to the original device
+        torch.jit.mobile._load_for_lite_interpreter(buffer)
+    """
+    if isinstance(f, (str, os.PathLike)):  # path input: reject bad paths up front
+        if not os.path.exists(f):
+            raise ValueError(f"The provided filename {f} does not exist")
+        if os.path.isdir(f):
+            raise ValueError(f"The provided filename {f} is a directory")
+
+    map_location = validate_map_location(map_location)
+
+    if isinstance(f, (str, os.PathLike)):
+        cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
+    else:  # anything else is treated as file-like and read in full
+        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
+            # pyrefly: ignore [missing-attribute]
+            f.read(),
+            map_location,
+        )
+
+    return LiteScriptModule(cpp_module)
+
+
+class LiteScriptModule:
+    def __init__(self, cpp_module) -> None:
+        self._c = cpp_module
+        super().__init__()
+
+    def __call__(self, *input):
+        return self._c.forward(input)
+
+    def find_method(self, method_name):
+        return self._c.find_method(method_name)
+
+    def forward(self, *input):
+        return self._c.forward(input)
+
+    def run_method(self, method_name, *input):
+        return self._c.run_method(method_name, input)
+
+
+def _export_operator_list(module: LiteScriptModule):
+    r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module."""
+    return torch._C._export_operator_list(module._c)
+
+
+def _get_model_bytecode_version(f_input) -> int:
+    r"""Return the bytecode version of the model in ``f_input`` (a path or a file-like object).
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string / os.PathLike containing a file name
+
+    Returns:
+        version: An integer. If the integer is -1, the version is invalid. A warning
+            will show in the log.
+
+    Example:
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_bytecode_version
+
+        # Get bytecode version from a saved file path
+        version = _get_model_bytecode_version("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):  # path input: reject bad paths up front
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_model_bytecode_version(os.fspath(f_input))
+    else:
+        # pyrefly: ignore [missing-attribute]
+        return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
+
+
+def _get_mobile_model_contained_types(f_input) -> int:
+    r"""Take a file-like object and return a set of string, like ("int", "Optional").
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        type_list: A set of string, like ("int", "Optional"). These are types used in bytecode.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_mobile_model_contained_types
+
+        # Get type list from a saved file path
+        type_list = _get_mobile_model_contained_types("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
+    else:
+        # pyrefly: ignore [missing-attribute]
+        return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
+
+
+def _backport_for_mobile(f_input, f_output, to_version):
+    r"""Backport the model in ``f_input`` to bytecode version ``to_version`` and write it to ``f_output``.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+        f_output: path to new model destination
+        to_version: the expected output model bytecode version
+    Returns:
+        success: A boolean. If backport success, return true, otherwise false
+    """
+    if isinstance(f_input, (str, os.PathLike)):  # path input: reject bad paths up front
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if (isinstance(f_input, (str, os.PathLike))) and (
+        isinstance(f_output, (str, os.PathLike))
+    ):
+        return torch._C._backport_for_mobile(
+            os.fspath(f_input),
+            os.fspath(f_output),
+            to_version,
+        )
+    else:  # any other combination: f_input must be readable, f_output is stringified
+        return torch._C._backport_for_mobile_from_buffer(
+            # pyrefly: ignore [missing-attribute]
+            f_input.read(),
+            str(f_output),
+            to_version,
+        )
+
+
+def _backport_for_mobile_to_buffer(f_input, to_version):
+    r"""Backport the model in ``f_input`` to bytecode version ``to_version`` and return the binding's result.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+        to_version: the expected output model bytecode version
+    """
+    if isinstance(f_input, (str, os.PathLike)):  # path input: reject bad paths up front
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
+    else:  # anything else is treated as file-like and read in full
+        return torch._C._backport_for_mobile_from_buffer_to_buffer(
+            # pyrefly: ignore [missing-attribute]
+            f_input.read(),
+            to_version,
+        )
+
+
+def _get_model_ops_and_info(f_input):
+    r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.
+
+    These root operators can call other operators within them (traced ops), and
+    a root op can call many different traced ops depending on internal code paths in the root op.
+    These traced ops are not returned by this function. Those operators are abstracted into the
+    runtime as an implementation detail (and the traced ops themselves can also call other operators)
+    making retrieving them difficult and their value from this api negligible since they will differ
+    between which runtime version the model is run on. Because of this, there is a false positive this
+    api can't prevent in a compatibility usecase. All the root ops of a model are present in a
+    target runtime, but not all the traced ops are which prevents a model from being able to run.
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
+        of the model to their OperatorInfo structs.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_ops_and_info
+
+        # Get bytecode version from a saved file path
+        ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_model_ops_and_info(os.fspath(f_input))
+    else:
+        # pyrefly: ignore [missing-attribute]
+        return torch._C._get_model_ops_and_info(f_input.read())
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..e441ff5a28936d8ca999fcb61ddc8dbbb2c8c12b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include 
+
+struct AllocInfo {
+  pid_t pid;
+  char free;
+  char filename[60];
+};
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/err.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/err.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1e6aa4e277c3a94dd642ff2a27e6cd564322e46
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/err.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include 
+#include 
+
+// `errno` is only meaningful when it fails. E.g., a  successful `fork()` sets
+// `errno` to `EINVAL` in child process on some macos
+// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
+// be inspected if an error occurred.
+//
+// All functions used in `libshm` (so far) indicate error by returning `-1`. If
+// you want to use a function with a different error reporting mechanism, you
+// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`.
+#define SYSCHECK_ERR_RETURN_NEG1(expr)                          \
+  while (true) {                                                \
+    if ((expr) == -1) {                                         \
+      if (errno == EINTR) {                                     \
+        continue;                                               \
+      } else {                                                  \
+        throw std::system_error(errno, std::system_category()); \
+      }                                                         \
+    } else {                                                    \
+      break;                                                    \
+    }                                                           \
+  }
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/libshm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/libshm.h
new file mode 100644
index 0000000000000000000000000000000000000000..d3f7c7061abc9e56b7147fad7e85d1bcdacc61c8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/libshm.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include 
+
+#ifdef __cplusplus
+
+void libshm_init(const char* manager_exec_path);
+
+// Superclass to run a constructor before at::RefcountedMapAllocator
+class THManagedMapAllocatorInit {
+ protected:
+  THManagedMapAllocatorInit(const char* manager_handle, const char* filename);
+  std::string manager_handle_;
+};
+
+// Like a at::RefcountedMapAllocator, but it also makes use of an external
+// shared memory manager process to ensure that shared memory regions actually
+// get freed in the end (even if processes lose the memory).
+class THManagedMapAllocator : private THManagedMapAllocatorInit,
+                              public at::RefcountedMapAllocator {
+ public:
+  THManagedMapAllocator(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+
+  void close() override;
+
+  ~THManagedMapAllocator() override {
+    close();
+  }
+
+  static at::DataPtr makeDataPtr(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+  static THManagedMapAllocator* fromDataPtr(const at::DataPtr& /*dptr*/);
+
+  const char* manager_handle() const {
+    return manager_handle_.c_str();
+  }
+};
+
+#endif
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/socket.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/socket.h
new file mode 100644
index 0000000000000000000000000000000000000000..e048098b94efac3360d4d72835d60b346fab4842
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm/socket.h
@@ -0,0 +1,164 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+class Socket {
+ public:
+  int socket_fd;
+  Socket(const Socket& other) = delete;
+
+ protected:
+  Socket() {
+    SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
+  }
+  Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
+    other.socket_fd = -1; // moved-from object must not close the fd
+  };
+  explicit Socket(int fd) : socket_fd(fd) {}
+
+  virtual ~Socket() {
+    if (socket_fd != -1)
+      close(socket_fd);
+  }
+
+  struct sockaddr_un prepare_address(const char* path) {
+    struct sockaddr_un address;
+    address.sun_family = AF_UNIX;
+    strcpy(address.sun_path, path); // NOTE(review): unbounded copy into sun_path
+    return address;
+  }
+
+  // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
+  size_t address_length(struct sockaddr_un address) {
+    return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
+  }
+
+  void recv(void* _buffer, size_t num_bytes) {
+    char* buffer = (char*)_buffer;
+    size_t bytes_received = 0;
+    ssize_t step_received;
+    struct pollfd pfd = {};
+    pfd.fd = socket_fd;
+    pfd.events = POLLIN;
+    while (bytes_received < num_bytes) {
+      SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
+      if (pfd.revents & POLLIN) {
+        SYSCHECK_ERR_RETURN_NEG1(
+            step_received =
+                ::read(socket_fd, buffer, num_bytes - bytes_received));
+        TORCH_CHECK(step_received != 0, "Other end has closed the connection");
+        bytes_received += step_received;
+        buffer += step_received;
+      } else if (pfd.revents & (POLLERR | POLLHUP)) {
+        TORCH_CHECK(false, "An error occurred while waiting for the data");
+      } else {
+        TORCH_CHECK(false, "Shared memory manager connection has timed out");
+      }
+    }
+  }
+
+  void send(const void* _buffer, size_t num_bytes) {
+    const char* buffer = (const char*)_buffer;
+    size_t bytes_sent = 0;
+    ssize_t step_sent;
+    while (bytes_sent < num_bytes) { // write() may be partial: offer only the unsent tail
+      SYSCHECK_ERR_RETURN_NEG1(
+          step_sent = ::write(socket_fd, buffer, num_bytes - bytes_sent));
+      bytes_sent += step_sent;
+      buffer += step_sent;
+    }
+  }
+};
+
+class ManagerSocket : public Socket {
+ public:
+  explicit ManagerSocket(int fd) : Socket(fd) {}
+
+  AllocInfo receive() {
+    AllocInfo info;
+    recv(&info, sizeof(info));
+    return info;
+  }
+
+  void confirm() {
+    send("OK", 2);
+  }
+};
+
+class ManagerServerSocket : public Socket {
+ public:
+  explicit ManagerServerSocket(const std::string& path) : socket_path(path) {
+    try {
+      struct sockaddr_un address = prepare_address(path.c_str());
+      size_t len = address_length(address);
+      SYSCHECK_ERR_RETURN_NEG1(
+          bind(socket_fd, (struct sockaddr*)&address, len));
+      SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
+    } catch (std::exception&) {
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
+      socket_fd = -1; // ~Socket() still runs after the rethrow; avoid a second close
+      throw;
+    }
+  }
+
+  void remove() {
+    struct stat file_stat;
+    if (fstat(socket_fd, &file_stat) == 0)
+      SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
+  }
+
+  ~ManagerServerSocket() override {
+    unlink(socket_path.c_str());
+  }
+
+  ManagerSocket accept() {
+    int client_fd;
+    struct sockaddr_un addr;
+    socklen_t addr_len = sizeof(addr);
+    SYSCHECK_ERR_RETURN_NEG1(
+        client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
+    return ManagerSocket(client_fd);
+  }
+
+  std::string socket_path;
+};
+
+class ClientSocket : public Socket {
+ public:
+  explicit ClientSocket(const std::string& path) {
+    try {
+      struct sockaddr_un address = prepare_address(path.c_str());
+      SYSCHECK_ERR_RETURN_NEG1(connect(
+          socket_fd, (struct sockaddr*)&address, address_length(address)));
+    } catch (std::exception&) {
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
+      socket_fd = -1; // ~Socket() still runs after the rethrow; avoid a second close
+      throw;
+    }
+  }
+
+  void register_allocation(AllocInfo& info) {
+    char buffer[3] = {0, 0, 0};
+    send(&info, sizeof(info));
+    recv(buffer, 2);
+    TORCH_CHECK(
+        strcmp(buffer, "OK") == 0,
+        "Shared memory manager didn't respond with an OK");
+  }
+
+  void register_deallocation(AllocInfo& info) {
+    send(&info, sizeof(info));
+  }
+};
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h
new file mode 100644
index 0000000000000000000000000000000000000000..4dd193df93d110e3a04d33a3f9d3e3ec24948277
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include 
+
+#ifdef __cplusplus
+
+#ifdef SHM_EXPORTS
+#define SHM_API __declspec(dllexport)
+#else
+#define SHM_API __declspec(dllimport)
+#endif
+
+SHM_API void libshm_init(const char* manager_exec_path);
+
+class SHM_API THManagedMapAllocator : public at::RefcountedMapAllocator {
+ public:
+  THManagedMapAllocator(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size)
+      : at::RefcountedMapAllocator(filename, flags, size) {}
+
+  static at::DataPtr makeDataPtr(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+  static THManagedMapAllocator* fromDataPtr(const at::DataPtr&);
+
+  const char* manager_handle() const {
+    return "no_manager";
+  }
+};
+
+#endif
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..199878d69a387f8a76dd99dce26a5f169f60eedd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a9815e25ef36e299bd9b8cc21cc8b31b1861ba1a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bec50f89cd4da0f8f4ee4d6199574e4659d470e4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef878d3c4b20ef38c7dfd6e14631e99b2fddcc1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from .binary import _apply_native_binary, _is_native_binary
+from .core import is_masked_tensor, MaskedTensor
+from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
+from .reductions import _apply_reduction, _is_reduction
+from .unary import _apply_native_unary, _is_native_unary
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b461a70f919c89c05172e9174718b237d795c671
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9add0a1dfbae1f8dee18fecdcfd3f60da5231d7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py
@@ -0,0 +1,547 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from collections.abc import Callable
+from functools import partial
+from typing import Any, TYPE_CHECKING
+
+import torch
+
+from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
+from .core import (
+    _get_data,
+    _masks_match,
+    _maybe_get_mask,
+    is_masked_tensor,
+    MaskedTensor,
+)
+from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
+from .reductions import (
+    _apply_reduction,
+    NATIVE_REDUCE_FNS,
+    TENSOR_REDUCE_FNS,
+    TORCH_REDUCE_FNS,
+)
+from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
+
+
+if TYPE_CHECKING:
+    from torch._ops import OpOverload
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _check_args_kwargs_length(
+    args, kwargs, error_prefix, len_args=None, len_kwargs=None
+):
+    if len_args is not None and len_args != len(args):
+        raise ValueError(
+            f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
+        )
+    if len_kwargs is not None and len_kwargs != len(kwargs):
+        raise ValueError(
+            f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
+        )
+
+
+class _MaskedContiguous(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
+
+        if input.is_contiguous():
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+
+        return MaskedTensor(data.contiguous(), mask.contiguous())
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class _MaskedToDense(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
+
+        if input.layout == torch.strided:
+            return input
+
+        ctx.layout = input.layout
+        data = input.get_data()
+        mask = input.get_mask()
+
+        return MaskedTensor(data.to_dense(), mask.to_dense())
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        layout = ctx.layout
+
+        if layout == torch.sparse_coo:
+            return grad_output.to_sparse_coo()
+        elif layout == torch.sparse_csr:
+            return grad_output.to_sparse_csr()
+        elif layout == torch.strided:
+            return grad_output.to_dense()
+        raise ValueError("to_dense: Unsupported input layout: ", layout)
+
+
+class _MaskedToSparse(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
+
+        # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo
+        if input.layout == torch.sparse_coo:
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+        sparse_mask = mask.to_sparse_coo().coalesce()
+        sparse_data = data.sparse_mask(sparse_mask)
+
+        return MaskedTensor(sparse_data, sparse_mask)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        return grad_output.to_dense()
+
+
+class _MaskedToSparseCsr(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
+
+        if input._masked_data.ndim != 2:
+            raise ValueError(
+                f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
+            )
+
+        if input.layout == torch.sparse_csr:
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+        sparse_mask = mask.to_sparse_csr()
+        sparse_data = data.sparse_mask(sparse_mask)
+
+        return MaskedTensor(sparse_data, sparse_mask)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        return grad_output.to_dense()
+
+
+class _MaskedWhere(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, cond, self, other):
+        ctx.mark_non_differentiable(cond)
+        ctx.save_for_backward(cond)
+        return torch.ops.aten.where(cond, self, other)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        (cond,) = ctx.saved_tensors
+
+        def masked_out_like(mt):
+            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())
+
+        return (
+            None,
+            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
+            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
+        )
+
+
+_MASKEDTENSOR_FUNCTION_TABLE = {}
+
+_function_fn_apply_map = {
+    (
+        tuple(NATIVE_REDUCE_FNS),
+        tuple(TORCH_REDUCE_FNS),
+        tuple(TENSOR_REDUCE_FNS),
+    ): _apply_reduction,
+}
+
+for fn_map_list, apply_fn in _function_fn_apply_map.items():
+    for fn_map in fn_map_list:
+        for fn in fn_map:
+            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)
+
+
+def register_function_func(ops):
+    """
+    Used for registering a new __torch_function__ function to MaskedTensor
+    Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
+
+    The code to register a new function looks like:
+
+    @register_function_func(list_of_ops)
+    def foo(func, *args, **kwargs):
+        ...
+    """
+
+    def wrapper(func):  # returns None, so the decorated name ends up bound to None
+        for op in ops:
+            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)  # op becomes the first argument
+
+    return wrapper
+
+
+@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
+def _general_function_reductions(func, *args, **kwargs):
+    return _apply_reduction(func, *args, **kwargs)
+
+
+@register_function_func([torch.Tensor.where, torch.where])
+def _function_where(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
+    )
+    return _MaskedWhere.apply(*args)
+
+
+@register_function_func([torch.Tensor.contiguous])
+def _function_contiguous(func, *args, **kwargs):
+    return _MaskedContiguous.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_dense])
+def _function_to_dense(func, *args, **kwargs):
+    return _MaskedToDense.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_sparse])
+def _function_to_sparse(func, *args, **kwargs):
+    return _MaskedToSparse.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_sparse_csr])
+def _function_to_sparse_csr(func, *args, **kwargs):
+    return _MaskedToSparseCsr.apply(args[0])
+
+
+_MASKEDTENSOR_DISPATCH_TABLE: dict["OpOverload", Callable[..., Any]] = {}
+
+
+def register_dispatch_func(aten_ops):
+    """
+    Used for registering a new __torch_dispatch__ function to MaskedTensor
+    Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
+
+    The code to register a new function looks like:
+
+    @register_dispatch_func(list_of_ops)
+    def foo(func, *args, **kwargs):
+        ...
+    """
+
+    def wrapper(func):  # returns None, so the decorated name ends up bound to None
+        for aten_op in aten_ops:
+            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)  # op becomes the first argument
+
+    return wrapper
+
+
+@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
+def _general_reduction(func, *args, **kwargs):
+    """``__torch_dispatch__`` handler for reductions; delegates to ``_apply_reduction``."""
+    return _apply_reduction(func, *args, **kwargs)
+
+
+@register_dispatch_func(PASSTHROUGH_FNS)
+def _general_passthrough(func, *args, **kwargs):
+    """``__torch_dispatch__`` handler for pass-through ops; delegates to ``_apply_pass_through_fn``."""
+    return _apply_pass_through_fn(func, *args, **kwargs)
+
+
+@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
+def _general_unary(func, *args, **kwargs):
+    """``__torch_dispatch__`` handler for unary ops (in-place included); delegates to ``_apply_native_unary``."""
+    return _apply_native_unary(func, *args, **kwargs)
+
+
+@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
+def _general_binary(func, *args, **kwargs):
+    """``__torch_dispatch__`` handler for binary ops (in-place included); delegates to ``_apply_native_binary``."""
+    return _apply_native_binary(func, *args, **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.stride])
+def stride(func, *args, **kwargs):
+    """Dispatch handler for ``aten.stride``: always returns None.
+
+    NOTE(review): presumably None tells the caller to fall back to the
+    wrapper's own strides -- confirm against the dispatch contract.
+    """
+    return None
+
+
+@register_dispatch_func([torch.ops.aten.sym_stride])
+def sym_stride(func, *args, **kwargs):
+    """Dispatch handler for ``aten.sym_stride``: always returns None.
+
+    NOTE(review): presumably None tells the caller to fall back to the
+    wrapper's own strides -- confirm against the dispatch contract.
+    """
+    return None
+
+
+@register_dispatch_func([torch.ops.prim.layout])
+def layout(func, *args, **kwargs):
+    """Dispatch handler for ``prim.layout``: report the layout of the underlying data tensor."""
+    return _get_data(args[0]).layout
+
+
+@register_dispatch_func(
+    [torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous]
+)
+def is_contiguous(func, *args, **kwargs):
+    """Answer a contiguity query by running ``func`` on the underlying data.
+
+    Raises:
+        ValueError: if the underlying data reports ``is_sparse``.
+    """
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
+    # Re-issue the op on the raw data, keeping all remaining arguments.
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.is_strides_like_format])
+def is_strides_like_format(func, *args, **kwargs):
+    """Answer ``is_strides_like_format`` by running ``func`` on the underlying data.
+
+    Raises:
+        ValueError: if the underlying data reports ``is_sparse``.
+    """
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError(
+            "MaskedTensors with sparse data do not have is_strides_like_format"
+        )
+    # Re-issue the op on the raw data, keeping all remaining arguments.
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
+def is_non_overlapping_and_dense(func, *args, **kwargs):
+    """Answer ``is_non_overlapping_and_dense`` by running ``func`` on the underlying data.
+
+    Raises:
+        ValueError: if the underlying data reports ``is_sparse``.
+    """
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError(
+            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
+        )
+    # Re-issue the op on the raw data, keeping all remaining arguments.
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.contiguous])
+def contiguous(func, *args, **kwargs):
+    """Dispatch handler for ``aten.contiguous``.
+
+    Applies ``_MaskedContiguous`` to the first argument; ``func`` and the
+    remaining arguments are not used.
+
+    Raises:
+        ValueError: if the underlying data reports ``is_sparse``.
+    """
+    if _get_data(args[0]).is_sparse:
+        raise ValueError("MaskedTensors with sparse data do not have contiguous")
+    return _MaskedContiguous.apply(args[0])
+
+
+@register_dispatch_func([torch.ops.aten.new_empty_strided])
+def new_empty_strided(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    if tuple(args[1]) != tuple(data.size()):
+        raise ValueError(
+            f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
+        )
+    if tuple(args[2]) != tuple(data.stride()):
+        raise ValueError(
+            f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
+        )
+    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)
+
+
+@register_dispatch_func([torch.ops.aten._local_scalar_dense])
+def _local_scalar_dense(func, *args, **kwargs):
+    """Dispatch handler for ``aten._local_scalar_dense``: extract the scalar from the data.
+
+    Raises:
+        ValueError: if the mask of ``args[0]`` is falsy (see note below).
+    """
+    # NOTE(review): ``not mask`` is true both when ``args[0]`` has no mask
+    # (None) and when the mask tensor itself evaluates to False, so a
+    # masked-out scalar also raises here -- confirm whether ``is None`` was meant.
+    if not _maybe_get_mask(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
+    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
+def _apply_fn_on_data(func, *args, **kwargs):
+    """Dispatch handler for ``detach`` / ``clone``.
+
+    Runs ``func`` on the data only and pairs the result with the existing mask
+    object. Arguments after the first, and all keyword arguments, are not
+    forwarded to ``func``.
+    """
+    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._to_copy])
+def _to_copy(func, *args, **kwargs):
+    """Dispatch handler for ``aten._to_copy``: convert both data and mask.
+
+    The data is converted with the caller's arguments. The mask is converted
+    with the same arguments except that ``dtype`` is forced to ``torch.bool``.
+    """
+    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
+    # Copy first so the caller's kwargs dict is not mutated.
+    cloned_kwargs = kwargs.copy()
+    cloned_kwargs["dtype"] = torch.bool
+    new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._softmax])
+def _softmax(func, *args, **kwargs):
+    """Dispatch handler for ``aten._softmax`` built on ``aten._masked_softmax``.
+
+    ``args[1]`` is forwarded to ``_masked_softmax``; ``args[2]`` is counted by
+    the length check but otherwise unused. The result keeps the input mask.
+    """
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
+    )
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    # The mask is inverted before being passed on.
+    # NOTE(review): the literal 2 is presumably the ``mask_type`` argument of
+    # ``_masked_softmax`` -- confirm against the op schema.
+    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
+    return MaskedTensor(result_data, mask)
+
+
+@register_dispatch_func([torch.ops.aten.ones_like])
+def ones_like(func, *args, **kwargs):
+    """Dispatch handler for ``aten.ones_like``.
+
+    Runs ``func`` on the data (keyword arguments forwarded) and wraps the
+    result with the mask of the input.
+    """
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
+    result_data = func(_get_data(args[0]), **kwargs)
+    return MaskedTensor(result_data, _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._softmax_backward_data])
+def _softmax_backward_data(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
+    grad, output, dim, _input_dtype = args
+    if is_masked_tensor(grad) and is_masked_tensor(output):
+        if not _masks_match(grad, output):
+            raise ValueError(
+                f"__torch_dispatch__, {func}: expected the masks of grad and output to match"
+            )
+        grad_data = _get_data(grad)
+        new_grad_data = torch.ops.aten._masked_softmax_backward(
+            grad_data,
+            _get_data(output),
+            ~_maybe_get_mask(grad),
+            dim % grad_data.ndim,
+        )
+        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
+        return res
+    else:
+        raise ValueError(
+            f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
+        )
+
+
+@register_dispatch_func([torch.ops.aten.copy_])
+def copy_(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
+    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
+        raise ValueError("args[0] mask and args[1] mask must match but do not")
+    func(_get_data(args[0]), _get_data(args[1]))
+    return args[0]
+
+
+@register_dispatch_func([torch.ops.aten.where])
+def where(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mx = args[1]
+    my = args[2]
+    if not is_masked_tensor(mx):
+        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
+    if not is_masked_tensor(my):
+        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
+    new_data = func(args[0], mx.get_data(), my.get_data())
+    new_mask = func(args[0], mx.get_mask(), my.get_mask())
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_sparse])
+def _to_sparse(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
+    if mt.is_sparse_coo():
+        return mt
+    new_mask = func(_maybe_get_mask(args[0])).coalesce()
+    new_data = _get_data(args[0]).sparse_mask(new_mask)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_sparse_csr])
+def _to_sparse_csr(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
+    if mt.is_sparse_csr():
+        return mt
+    new_mask = func(_maybe_get_mask(args[0]))
+    new_data = _get_data(args[0]).sparse_mask(new_mask)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_dense])
+def _to_dense(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
+    new_data = func(_get_data(args[0]))
+    new_mask = func(_maybe_get_mask(args[0]))
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._indices])
+def _indices(func, *args, **kwargs):
+    """Dispatch handler for ``aten._indices``.
+
+    Returns the ``indices()`` of the underlying data wrapped in a MaskedTensor
+    whose mask is all True. ``func`` itself is not called.
+    """
+    # Assumes data is sparse
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0]).indices()
+    return MaskedTensor(data, torch.ones_like(data).bool())
+
+
+@register_dispatch_func([torch.ops.aten._values])
+def _values(func, *args, **kwargs):
+    """Dispatch handler for ``aten._values``.
+
+    Returns the ``values()`` of the underlying data wrapped in a MaskedTensor
+    whose mask is all True. ``func`` itself is not called.
+    """
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0]).values()
+    return MaskedTensor(data, torch.ones_like(data).bool())
+
+
+@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
+def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
+    """Dispatch handler that builds a sparse COO MaskedTensor.
+
+    MaskedTensors found in the last two positional arguments are replaced by
+    their data before ``func`` is called. The mask is produced by calling
+    ``func`` a second time with the last argument replaced by ones and
+    converting the result to bool.
+    """
+    new_args = list(args)
+    if is_masked_tensor(args[-1]):
+        new_args[-1] = args[-1].get_data()
+    if is_masked_tensor(args[-2]):
+        new_args[-2] = args[-2].get_data()
+
+    new_data = func(*new_args, **kwargs)
+    # Same construction with all-ones in place of the last argument gives a
+    # tensor with the sparsity pattern of ``new_data``; ``.bool()`` makes it a mask.
+    new_args[-1] = torch.ones_like(new_args[-1])
+    new_mask = func(*new_args, **kwargs).bool()
+
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten.is_same_size])
+def is_same_size(func, *args, **kwargs):
+    """Dispatch handler for ``aten.is_same_size``: compare the sizes of the underlying data tensors."""
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
+    return _get_data(args[0]).is_same_size(_get_data(args[1]))
+
+
+@register_dispatch_func([torch.ops.aten._is_any_true])
+def _is_any_true(func, *args, **kwargs):
+    """Dispatch handler for ``aten._is_any_true`` on a boolean, non-sparse MaskedTensor.
+
+    ``func`` is applied to ``data & mask``, so only elements whose mask is
+    True can contribute. The result is wrapped with a scalar True mask.
+
+    Raises:
+        ValueError: if ``args[0]`` has no mask, the data is not boolean, or
+            the data is sparse.
+    """
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    if mask is None:
+        raise ValueError(
+            f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor"
+        )
+    if data.dtype != torch.bool:
+        raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor")
+    if data.is_sparse:
+        raise ValueError(f"MaskedTensors with sparse data do not have {func}")
+
+    return MaskedTensor(func(data & mask), torch.tensor(True))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..8315ae11be7175c2b5aaef178a4bc4785dcbcb29
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py
@@ -0,0 +1,200 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+
+from .core import (
+    _map_mt_args_kwargs,
+    _masks_match,
+    _tensors_match,
+    _wrap_result,
+    is_masked_tensor,
+)
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+# Names of the binary ops handled for MaskedTensor. Each name is looked up as
+# an attribute of ``torch.ops.aten`` when the dispatch maps below are built.
+BINARY_NAMES = [
+    "add",
+    "atan2",
+    "arctan2",
+    "bitwise_and",
+    "bitwise_or",
+    "bitwise_xor",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "div",
+    "divide",
+    "floor_divide",
+    "fmod",
+    "logaddexp",
+    "logaddexp2",
+    "mul",
+    "multiply",
+    "nextafter",
+    "remainder",
+    "sub",
+    "subtract",
+    "true_divide",
+    "eq",
+    "ne",
+    "le",
+    "ge",
+    "greater",
+    "greater_equal",
+    "gt",
+    "less_equal",
+    "lt",
+    "less",
+    "maximum",
+    "minimum",
+    "fmax",
+    "fmin",
+    "not_equal",
+]
+
+INPLACE_BINARY_NAMES = [
+    n + "_"
+    for n in (
+        list(
+            set(BINARY_NAMES)
+            - {
+                "logaddexp",
+                "logaddexp2",
+                "equal",
+                "fmin",
+                "minimum",
+                "maximum",
+                "fmax",
+            }
+        )
+    )
+]
+
+
+def _get_at_least_one_mask(a, b):
+    if not is_masked_tensor(a) and not is_masked_tensor(b):
+        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
+    if not _masks_match(a, b):
+        raise ValueError("a and b must have matching masks")
+    if is_masked_tensor(a):
+        return a.get_mask()
+    return b.get_mask()
+
+
+def _binary_helper(fn, args, kwargs, inplace):
+    """Apply the binary op ``fn`` to the data of the MaskedTensor operands.
+
+    Args:
+        fn: the op to run on the underlying data.
+        args: positional arguments; the first two are the lhs and rhs.
+        kwargs: must be empty.
+        inplace: when true, the result is written into ``args[0]``, which is
+            then returned; otherwise a new wrapped result is returned.
+
+    Raises:
+        ValueError: on non-empty kwargs, mismatching masks, or mismatching
+            sparse indices / sizes.
+        TypeError: if any argument after the first two is a tensor.
+    """
+    if len(kwargs) != 0:
+        raise ValueError("len(kwargs) must equal 0")
+    for a in args[2:]:
+        if torch.is_tensor(a):
+            raise TypeError(
+                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
+            )
+
+    if not _masks_match(*args[:2]):
+        raise ValueError(
+            "Input masks must match. If you need support for this, please open an issue on Github."
+        )
+
+    # Parallel argument lists: MaskedTensors replaced by their data / their mask.
+    data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
+    mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
+
+    args0_layout = data_args[0].layout
+    # True only when the rhs is a tensor with the same layout as the lhs;
+    # a non-tensor rhs (e.g. a Python scalar) yields False.
+    same_layout = (
+        torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
+    ) and (args0_layout == data_args[1].layout)
+
+    if args0_layout == torch.sparse_coo:
+        if same_layout:
+            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
+                raise ValueError(
+                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
+                )
+            if data_args[0].size() != data_args[1].size():
+                raise ValueError(
+                    "input1 and input2 must have the same size for binary functions."
+                )
+
+            data_args[1] = data_args[1].values()
+
+        # Run fn on the values only, then rebuild a COO tensor with the lhs
+        # indices and size.
+        i = data_args[0].indices()
+        size = data_args[0].size()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_coo_tensor(i, v, size)
+
+    elif args0_layout == torch.sparse_csr:
+        if same_layout:
+            if not (
+                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
+                and _tensors_match(
+                    data_args[0].col_indices(), data_args[1].col_indices()
+                )
+            ):
+                raise ValueError(
+                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
+                )
+
+            data_args[1] = data_args[1].values()
+
+        # Run fn on the values only, then rebuild a CSR tensor with the lhs
+        # row/column indices and size.
+        crow = data_args[0].crow_indices()
+        col = data_args[0].col_indices()
+        size = data_args[0].size()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_csr_tensor(crow, col, v, size)
+
+    else:
+        result_data = fn(*data_args)
+
+    if inplace:
+        # The mask of the first operand is kept as the result mask.
+        args[0]._set_data_mask(result_data, mask_args[0])
+        return args[0]
+    else:
+        result_mask = _get_at_least_one_mask(*args[:2])
+        # sparse tensors don't have strides so we can only expand if the layout is strided
+        if args0_layout == torch.strided:
+            result_mask = result_mask.expand_as(result_data)
+        return _wrap_result(result_data, result_mask)
+
+
+def _torch_binary(fn_name):
+    """Build the out-of-place MaskedTensor handler for ``torch.ops.aten.<fn_name>``."""
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def binary_fn(*args, **kwargs):
+        return _binary_helper(fn, args, kwargs, inplace=False)
+
+    return binary_fn
+
+
+def _torch_inplace_binary(fn_name):
+    """Build the in-place MaskedTensor handler for ``torch.ops.aten.<fn_name>``."""
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def binary_fn(*args, **kwargs):
+        return _binary_helper(fn, args, kwargs, inplace=True)
+
+    return binary_fn
+
+
+# aten op -> MaskedTensor handler, for the out-of-place binary ops.
+NATIVE_BINARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
+}
+# aten op -> MaskedTensor handler, for the in-place binary ops.
+NATIVE_INPLACE_BINARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_inplace_binary(name)
+    for name in INPLACE_BINARY_NAMES
+}
+
+# The supported ops as plain lists (the keys of the two maps above).
+NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
+NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())
+
+
+def _is_native_binary(fn):
+    return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS
+
+
+def _apply_native_binary(fn, *args, **kwargs):
+    if fn in NATIVE_BINARY_FNS:
+        return NATIVE_BINARY_MAP[fn](*args, **kwargs)
+    if fn in NATIVE_INPLACE_BINARY_FNS:
+        return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
+    return NotImplemented
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..cad5621b29bd663ef4462f1be6c8f8f2c4762c2d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py
@@ -0,0 +1,364 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import warnings
+from typing import Any
+from typing_extensions import TypeIs
+
+import torch
+from torch.overrides import get_default_nowrap_functions
+
+
+__all__ = [
+    "MaskedTensor",
+    "is_masked_tensor",
+]
+
+
+def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
+    r"""Returns True if the input is a MaskedTensor, else False
+
+    Args:
+        obj: any input (positional-only)
+
+    Examples:
+
+        >>> # xdoctest: +SKIP
+        >>> from torch.masked import MaskedTensor
+        >>> data = torch.arange(6).reshape(2, 3)
+        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
+        >>> mt = MaskedTensor(data, mask)
+        >>> is_masked_tensor(mt)
+        True
+    """
+    return isinstance(obj, MaskedTensor)
+
+
+def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
+    if is_masked_tensor(a) or is_masked_tensor(b):
+        raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
+    if a.layout != b.layout:
+        raise ValueError(
+            f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
+        )
+
+    if a.dtype != b.dtype:
+        b = b.type(a.dtype)
+    if a.layout == b.layout == torch.sparse_coo:
+        return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
+            a.indices(), b.indices(), exact
+        )
+    elif a.layout == b.layout == torch.sparse_csr:
+        return (
+            _tensors_match(a.crow_indices(), b.crow_indices(), exact)
+            and _tensors_match(a.col_indices(), b.col_indices(), exact)
+            and _tensors_match(a.values(), b.values(), exact)
+        )
+    if exact:
+        return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
+    return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)
+
+
+def _masks_match(a, b):
+    if is_masked_tensor(a) and is_masked_tensor(b):
+        mask_a = a.get_mask()
+        mask_b = b.get_mask()
+        return _tensors_match(mask_a, mask_b, exact=True)
+    return True
+
+
+def _map_mt_args_kwargs(args, kwargs, map_fn):
+    def _helper(a, map_fn):
+        if is_masked_tensor(a):
+            return map_fn(a)
+        elif torch.is_tensor(a):
+            return a
+        elif isinstance(a, list):
+            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
+            return a_impl
+        elif isinstance(a, tuple):
+            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
+            return tuple(a_impl)
+        else:
+            return a
+
+    if kwargs is None:
+        kwargs = {}
+    impl_args = []
+    for a in args:
+        impl_args.append(_helper(a, map_fn))
+    impl_kwargs = {}
+    for k in kwargs:
+        impl_kwargs[k] = _helper(a, map_fn)
+    return impl_args, impl_kwargs
+
+
+def _wrap_result(result_data, result_mask):
+    if isinstance(result_data, list):
+        return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
+    if isinstance(result_data, tuple):
+        return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
+    if torch.is_tensor(result_data):
+        return MaskedTensor(result_data, result_mask)
+    # Expect result_data and result_mask to be Tensors only
+    return NotImplemented
+
+
+def _masked_tensor_str(data, mask, formatter):
+    """Render ``data`` as nested bracketed text, printing ``--`` where ``mask`` is False.
+
+    Args:
+        data: tensor to print; sparse COO/CSR inputs are densified first.
+        mask: boolean tensor matching ``data``.
+        formatter: ``str.format`` template applied to float elements; other
+            element types are rendered with ``str``.
+    """
+    if data.layout in {torch.sparse_coo, torch.sparse_csr}:
+        data = data.to_dense()
+        mask = mask.to_dense()
+    if data.dim() == 1:
+        formatted_elements = [
+            formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
+            for d in data
+        ]
+        # Width used to right-justify the "--" placeholders: masked-out
+        # positions count as 8, visible ones as their formatted length.
+        max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
+        return (
+            "["
+            + ", ".join(
+                [
+                    "--".rjust(max_len) if m else e
+                    for (e, m) in zip(formatted_elements, ~mask)
+                ]
+            )
+            + "]"
+        )
+    # Higher-rank input: format each slice along the first dimension, then
+    # indent every line of each slice by two spaces.
+    sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
+    sub_strings = ["\n".join(["  " + si for si in s.split("\n")]) for s in sub_strings]
+    return "[\n" + ",\n".join(sub_strings) + "\n]"
+
+
+def _get_data(a):
+    if is_masked_tensor(a):
+        return a._masked_data
+    return a
+
+
+def _maybe_get_mask(a):
+    if is_masked_tensor(a):
+        return a.get_mask()
+    return None
+
+
+class MaskedTensor(torch.Tensor):
+    """Tensor wrapper subclass that pairs a data tensor with a boolean mask.
+
+    The mask must have dtype bool and the same layout and size as the data
+    (see ``_validate_members``). Elements whose mask is False are treated as
+    masked out: they print as ``--`` and are overwritten by ``to_tensor``.
+    """
+
+    @staticmethod
+    def __new__(cls, data, mask, requires_grad=False):
+        """Create the wrapper tensor; data and mask are attached later in ``__init__``.
+
+        Raises:
+            TypeError: if ``data`` or ``mask`` is not a plain (non-masked) tensor.
+        """
+        if is_masked_tensor(data) or not torch.is_tensor(data):
+            raise TypeError("data must be a Tensor")
+        if is_masked_tensor(mask) or not torch.is_tensor(mask):
+            raise TypeError("mask must be a Tensor")
+        # Use a Tensor of the given size for the wrapper.
+        kwargs = {
+            "device": data.device,
+            "dtype": data.dtype,
+            "layout": data.layout,
+            "requires_grad": requires_grad,
+            "dispatch_sizes_strides_policy": "strides",
+            "dispatch_layout": True,
+        }
+        warnings.warn(
+            (
+                "The PyTorch API of MaskedTensors is in prototype stage "
+                "and will change in the near future. Please open a Github issue "
+                "for features requests and see our documentation on the torch.masked "
+                "module for further information about the project."
+            ),
+            UserWarning,
+            stacklevel=2,
+        )
+        if data.requires_grad:
+            warnings.warn(
+                "It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
+                "To avoid this, you can use data.detach().clone()",
+                UserWarning,
+                stacklevel=2,
+            )
+        # pyrefly: ignore [bad-argument-type]
+        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
+
+    def _preprocess_data(self, data, mask):
+        """Normalise sparse inputs and store clones of data and mask on ``self``.
+
+        Raises:
+            TypeError: if data and mask have different layouts.
+        """
+        from .._ops import _sparse_coo_where, _sparse_csr_where
+
+        if data.layout != mask.layout:
+            raise TypeError("data and mask must have the same layout.")
+        if data.layout == torch.sparse_coo:
+            data = data.coalesce()
+            mask = mask.coalesce()
+            # When the numbers of stored elements differ, rebuild the data
+            # through the sparse ``where`` helper, using the mask and a 0 fill.
+            if data._nnz() != mask._nnz():
+                data = _sparse_coo_where(mask, data, torch.tensor(0))
+        elif data.layout == torch.sparse_csr:
+            if data._nnz() != mask._nnz():
+                data = _sparse_csr_where(mask, data, torch.tensor(0))
+
+        # Have to pick awkward names to not conflict with existing fields such as data
+        self._masked_data = data.clone()
+        self._masked_mask = mask.clone()
+
+    def _validate_members(self):
+        """Check that the stored data and mask form a valid pair.
+
+        Raises:
+            TypeError: on differing Python types, an unsupported layout or
+                data dtype, or a non-bool mask.
+            ValueError: on mismatching sparse indices, dimensions or sizes.
+        """
+        data = self._masked_data
+        mask = self.get_mask()
+        if type(data) is not type(mask):
+            raise TypeError(
+                f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
+            )
+        if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
+            raise TypeError(f"data layout of {data.layout} is not supported.")
+        if data.layout == torch.sparse_coo:
+            if not _tensors_match(data.indices(), mask.indices(), exact=True):
+                raise ValueError(
+                    "data and mask are both sparse COO tensors but do not have the same indices."
+                )
+        elif data.layout == torch.sparse_csr:
+            if not _tensors_match(
+                data.crow_indices(), mask.crow_indices(), exact=True
+            ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
+                raise ValueError(
+                    "data and mask are both sparse CSR tensors but do not share either crow or col indices."
+                )
+        if mask.dtype != torch.bool:
+            raise TypeError("mask must have dtype bool.")
+        # Supported data dtypes: float16/32/64, bool, int8/16/32/64.
+        if not (
+            data.dtype == torch.float16
+            or data.dtype == torch.float32
+            or data.dtype == torch.float64
+            or data.dtype == torch.bool
+            or data.dtype == torch.int8
+            or data.dtype == torch.int16
+            or data.dtype == torch.int32
+            or data.dtype == torch.int64
+        ):
+            raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
+        if data.dim() != mask.dim():
+            raise ValueError("data.dim() must equal mask.dim()")
+        if data.size() != mask.size():
+            raise ValueError("data.size() must equal mask.size()")
+
+    def __init__(self, data, mask, requires_grad=False):
+        """Store (clones of) ``data`` and ``mask`` and validate them.
+
+        ``requires_grad`` is not read here; it is consumed by ``__new__``.
+        """
+        self._preprocess_data(data, mask)
+        self._validate_members()
+
+    @staticmethod
+    def _from_values(data, mask):
+        """Differentiable constructor for MaskedTensor"""
+
+        class Constructor(torch.autograd.Function):
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def forward(ctx, data, mask):
+                return MaskedTensor(data, mask)
+
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def backward(ctx, grad_output):
+                # The gradient passes straight through to ``data``; ``mask``
+                # receives none.
+                return grad_output, None
+
+        result = Constructor.apply(data, mask)
+        return result
+
+    def _set_data_mask(self, data, mask):
+        """Replace the stored data and mask (without cloning) and re-validate."""
+        self._masked_data = data
+        self._masked_mask = mask
+        self._validate_members()
+
+    def __repr__(self):  # type: ignore[override]
+        """Return a printable form in which masked-out elements appear as ``--``."""
+        formatter = "{0:8.4f}"
+        if self.dim() == 0:
+            # Scalar case: print the value (or "--") followed by the mask value.
+            scalar_data = self.get_data().item()
+            data_formatted = (
+                formatter.format(scalar_data)
+                if isinstance(scalar_data, float)
+                else str(scalar_data)
+            )
+            if not self.get_mask().item():
+                data_formatted = "--"
+            return (
+                "MaskedTensor("
+                + data_formatted
+                + ", "
+                + str(self.get_mask().item())
+                + ")"
+            )
+        s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
+        s = "\n".join("  " + si for si in s.split("\n"))
+        return "MaskedTensor(\n" + s + "\n)"
+
+    # Seems like this needs to be defined before torch_dispatch to work
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        """Route ``func`` to its registered handler, else run it with subclass handling disabled.
+
+        Returns ``NotImplemented`` when ``cls`` is not a subclass of every
+        type in ``types``.
+        """
+        kwargs = kwargs or {}
+
+        from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
+
+        if func in _MASKEDTENSOR_FUNCTION_TABLE:
+            return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
+
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+        with torch._C.DisableTorchFunctionSubclass():
+            ret = func(*args, **kwargs)
+            # Functions in the default no-wrap set return their result as is;
+            # everything else is converted back to ``cls``.
+            if func in get_default_nowrap_functions():
+                return ret
+            else:
+                return torch._tensor._convert(ret, cls)
+
+    @classmethod
+    def unary(cls, fn, data, mask):
+        """Return ``MaskedTensor(fn(data), mask)``."""
+        return MaskedTensor(fn(data), mask)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):  # type: ignore[override]
+        """Route a dispatched op to its registered MaskedTensor handler.
+
+        The lookup uses the op's overload packet. Unregistered ops emit a
+        warning and return ``NotImplemented``.
+        """
+        func = func.overloadpacket
+
+        from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
+
+        if func in _MASKEDTENSOR_DISPATCH_TABLE:
+            return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
+
+        msg = (
+            f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
+            "If you would like this operator to be supported, please file an issue for a feature request at "
+            "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
+            "In the case that the semantics for the operator are not trivial, it would be appreciated "
+            "to also include a proposal for the semantics."
+        )
+        warnings.warn(msg, stacklevel=2)
+        return NotImplemented
+
+    def __lt__(self, other):
+        """Elementwise ``<`` on the data; the result keeps this tensor's mask."""
+        if is_masked_tensor(other):
+            return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
+        return MaskedTensor(self.get_data() < other, self.get_mask())
+
+    def to_tensor(self, value):
+        """Return the data as a plain tensor with masked-out elements set to ``value``."""
+        return self.get_data().masked_fill(~self.get_mask(), value)
+
+    def get_data(self):
+        """Return the underlying data through an autograd function.
+
+        The forward pass returns the detached data. The backward pass returns
+        the incoming gradient as a MaskedTensor, wrapping it with this
+        tensor's mask when it is not one already.
+        """
+
+        class GetData(torch.autograd.Function):
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def forward(ctx, self):
+                return self._masked_data.detach()
+
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def backward(ctx, grad_output):
+                if is_masked_tensor(grad_output):
+                    return grad_output
+                return MaskedTensor(grad_output, self.get_mask())
+
+        return GetData.apply(self)
+
+    def get_mask(self):
+        """Return the stored mask tensor."""
+        return self._masked_mask
+
+    def is_sparse_coo(self):
+        """Return True if the layout is sparse COO."""
+        return self.layout == torch.sparse_coo
+
+    def is_sparse_csr(self):  # type: ignore[override]
+        """Return True if the layout is sparse CSR."""
+        return self.layout == torch.sparse_csr
+
+    # Update later to support more sparse layouts
+    @property
+    def is_sparse(self):  # type: ignore[override]
+        """True if the layout is sparse COO or sparse CSR."""
+        return self.is_sparse_coo() or self.is_sparse_csr()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py
new file mode 100644
index 0000000000000000000000000000000000000000..35c8e3d2aa9438dbcfc7995a1cdcd3c5cc8dc1fc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from .core import MaskedTensor
+
+
+__all__ = [
+    "as_masked_tensor",
+    "masked_tensor",
+]
+
+
+# These two factory functions are intended to mirror
+#     torch.tensor - guaranteed to be a leaf node
+#     torch.as_tensor - differentiable constructor that preserves the autograd history
+
+
+def masked_tensor(
+    data: object, mask: object, requires_grad: bool = False
+) -> MaskedTensor:
+    """Construct a MaskedTensor directly from ``data`` and ``mask``.
+
+    Thin wrapper around ``MaskedTensor(data, mask, requires_grad)``.
+    """
+    return MaskedTensor(data, mask, requires_grad)
+
+
+def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
+    """Construct a MaskedTensor through ``MaskedTensor._from_values``, the differentiable constructor."""
+    return MaskedTensor._from_values(data, mask)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba13f50c1fee9c9fc10563ffc9f4ff3211c0dca6
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py
@@ -0,0 +1,50 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+"""
+These are functions that should simply be applied to both mask and data.
+Take select or stack as an example. This operation can be applied to
+both the mask and data of a MaskedTensor and the result wrapped into
+a new MaskedTensor as a result.
+"""
+
+import torch
+
+from .core import _map_mt_args_kwargs, _wrap_result
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+PASSTHROUGH_FNS = [  # aten ops that _is_pass_through_fn recognises
+    torch.ops.aten.select,
+    torch.ops.aten.transpose,
+    torch.ops.aten.split,
+    torch.ops.aten.t,
+    torch.ops.aten.slice,
+    torch.ops.aten.slice_backward,
+    torch.ops.aten.select_backward,
+    torch.ops.aten.index,
+    torch.ops.aten.expand,
+    torch.ops.aten.view,
+    torch.ops.aten._unsafe_view,
+    torch.ops.aten._reshape_alias,
+    torch.ops.aten.cat,
+    torch.ops.aten.unsqueeze,
+    torch.ops.aten.unfold,
+    torch.ops.aten.unfold_backward,
+    torch.ops.aten.im2col,
+    torch.ops.aten.col2im,
+    torch.ops.aten.stack,
+]
+
+
+def _is_pass_through_fn(fn):  # membership test against the PASSTHROUGH_FNS list
+    return fn in PASSTHROUGH_FNS
+
+
+def _apply_pass_through_fn(fn, *args, **kwargs):
+    """Run ``fn`` on the unwrapped data, then on the unwrapped masks; wrap both."""
+    unwrappers = (lambda mt: mt.get_data(), lambda mt: mt.get_mask())
+    mapped = (_map_mt_args_kwargs(args, kwargs, unwrap) for unwrap in unwrappers)
+    data_out, mask_out = [fn(*a, **kw) for a, kw in mapped]
+    return _wrap_result(data_out, mask_out)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py
new file mode 100644
index 0000000000000000000000000000000000000000..6acc8415267bb9fdd7fe6af707cfbbaa74869184
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py
@@ -0,0 +1,176 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import warnings
+
+import torch
+
+from .core import is_masked_tensor
+from .creation import as_masked_tensor, masked_tensor
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _masked_all_all(data, mask=None):
+    """Reduce ``data`` with ``all()``, treating masked-out entries as True."""
+    filled = data if mask is None else data.masked_fill(~mask, True)
+    return filled.all()
+
+
+def _masked_all_dim(data, dim, keepdim=False, mask=None):
+    """``torch.all`` along ``dim``, with masked-out entries treated as True."""
+    filled = data if mask is None else data.masked_fill(~mask, True)
+    return torch.all(filled, dim=dim, keepdim=keepdim)
+
+
+def _masked_all(*args, **kwargs):
+    if (len(args), len(kwargs)) != (1, 1):  # anything else is a dim-wise call
+        return _masked_all_dim(*args, **kwargs)
+    return _masked_all_all(args[0], mask=kwargs["mask"])  # sole kwarg is read as ``mask``
+
+
+def _multidim_any(mask, dim, keepdim):
+    """Reduce ``mask`` with ``torch.any`` over one or several dims, largest dim first."""
+    dims = [dim] if isinstance(dim, int) else dim
+    for axis in sorted(dims, reverse=True):
+        mask = torch.any(mask, dim=axis, keepdim=keepdim)
+    return mask
+
+
+def _get_masked_fn(fn):
+    """Resolve ``fn``: the local ``_masked_all`` for "all", else ``torch.masked.<fn>``."""
+    # The conditional keeps the getattr lazy, so "all" is never looked up on torch.masked.
+    return _masked_all if fn == "all" else getattr(torch.masked, fn)
+
+
+def _torch_reduce_all(fn):
+    def reduce_all(self):  # full reduction of one MaskedTensor by the reduction named ``fn``
+        masked_fn = _get_masked_fn(fn)
+        data = self.get_data()
+        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
+        # Note on the argmin/argmax branch below: for a full reduction they need to return the
+        # index of the min/max element, but this isn't supported correctly for sparse layouts.
+        # Therefore, that branch calculates the index itself from the COO indices and strides.
+        if fn == "all":
+            result_data = masked_fn(data, mask=mask)
+
+        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
+            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
+            indices = (
+                data.to_sparse_coo().indices()
+                if not self.is_sparse_coo()  # NOTE(review): looks unreachable; the elif requires COO
+                else data.indices()
+            )
+            idx = indices.unbind(1)[sparse_idx]
+            stride = data.size().numel() / torch.tensor(
+                data.size(), device=data.device
+            ).cumprod(0)  # numel divided by the cumulative product of the sizes (true division)
+            result_data = torch.sum(idx * stride)
+
+        # we simply pass in the values for sparse COO/CSR tensors
+        elif self.is_sparse:
+            result_data = masked_fn(masked_tensor(data.values(), mask))
+
+        else:
+            result_data = masked_fn(self, mask=mask)
+
+        return as_masked_tensor(result_data, torch.any(mask))  # result mask is any(mask)
+
+    return reduce_all
+
+
+def _torch_reduce_dim(fn):
+    def reduce_dim(self, dim, keepdim=False, dtype=None):  # reduce ``self`` along ``dim``
+        if self.is_sparse:
+            msg = (
+                f"The sparse version of {fn} is not implemented in reductions.\n"
+                "If you would like this operator to be supported, please file an issue for a feature request at "
+                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
+                "In the case that the semantics for the operator are not trivial, it would be appreciated "
+                "to also include a proposal for the semantics."
+            )
+            warnings.warn(msg, stacklevel=2)
+            return NotImplemented  # sparse input: warn and decline instead of raising
+        if not is_masked_tensor(self):
+            raise TypeError("Input to reduce_dim must be a MaskedTensor")
+        # Non-sparse MaskedTensor from here on.
+        masked_fn = _get_masked_fn(fn)
+        data = self.get_data()
+        mask = self.get_mask()
+        if fn == "all":
+            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)  # dtype not forwarded
+        else:
+            result_data = masked_fn(
+                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
+            )
+        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))
+
+    return reduce_dim
+
+
+def _torch_reduce(fn):
+    """Build a dispatcher: one bare positional arg -> full reduction, else dim-wise."""
+    def reduce_fn(*args, **kwargs):
+        if kwargs or len(args) != 1:
+            return _torch_reduce_dim(fn)(*args, **kwargs)
+        return _torch_reduce_all(fn)(args[0])
+    return reduce_fn
+
+
+def _reduce_dim_args(input, dim, keepdim=False, dtype=None):  # binds args/kwargs to names
+    return input, dim, keepdim, dtype  # always a 4-tuple, defaults filled in
+
+
+def _torch_grad_reduce(fn):
+    """Build a dispatcher that hands the dim-wise reduction positional arguments only."""
+    def grad_reduce(*args, **kwargs):
+        if kwargs or len(args) != 1:
+            # TODO: autograd.Function doesn't support kwarg
+            normalized = _reduce_dim_args(*args, **kwargs)
+            return _torch_reduce_dim(fn)(*normalized)
+        return _torch_reduce_all(fn)(args[0])
+    return grad_reduce
+
+
+REDUCE_NAMES = [  # each name is looked up on torch.ops.aten, torch and torch.Tensor below
+    "sum",
+    "mean",
+    "amin",
+    "amax",
+    "argmin",
+    "argmax",
+    "prod",
+    "all",
+    "norm",
+    "var",
+    "std",
+]
+
+NATIVE_REDUCE_MAP = {  # torch.ops.aten.<name> -> wrapper that forwards kwargs as given
+    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
+}
+TORCH_REDUCE_MAP = {  # torch.<name> -> wrapper that passes the dim arguments positionally
+    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
+}
+TENSOR_REDUCE_MAP = {  # torch.Tensor.<name> -> same positional-argument wrapper
+    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
+}
+# Key lists of the three tables above.
+NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
+TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
+TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
+
+
+def _is_reduction(fn):  # True when ``fn`` is a key of any of the three reduce maps
+    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
+
+
+def _apply_reduction(fn, *args, **kwargs):
+    """Dispatch ``fn`` to its reduction wrapper; ``NotImplemented`` if unregistered."""
+    # Probe order: torch.ops.aten ops, then torch.* functions, then torch.Tensor methods.
+    tables = (NATIVE_REDUCE_MAP, TORCH_REDUCE_MAP, TENSOR_REDUCE_MAP)
+    for table in tables:
+        if fn in table:
+            return table[fn](*args, **kwargs)
+    return NotImplemented
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e04ee6e810a7418829b68323097612391017b14e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py
@@ -0,0 +1,194 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+
+from .core import _map_mt_args_kwargs, _wrap_result
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+UNARY_NAMES = [  # torch.ops.aten attribute names registered in NATIVE_UNARY_MAP below
+    "abs",
+    "absolute",
+    "acos",
+    "arccos",
+    "acosh",
+    "arccosh",
+    "angle",
+    "asin",
+    "arcsin",
+    "asinh",
+    "arcsinh",
+    "atan",
+    "arctan",
+    "atanh",
+    "arctanh",
+    "bitwise_not",
+    "ceil",
+    "clamp",
+    "clip",
+    "conj_physical",
+    "cos",
+    "cosh",
+    "deg2rad",
+    "digamma",
+    "erf",
+    "erfc",
+    "erfinv",
+    "exp",
+    "exp2",
+    "expm1",
+    "fix",
+    "floor",
+    "frac",
+    "lgamma",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "logit",
+    "i0",
+    "isnan",
+    "nan_to_num",
+    "neg",
+    "negative",
+    "positive",
+    "pow",
+    "rad2deg",
+    "reciprocal",
+    "round",
+    "rsqrt",
+    "sigmoid",
+    "sign",
+    "sgn",
+    "signbit",
+    "sin",
+    "sinc",
+    "sinh",
+    "sqrt",
+    "square",
+    "tan",
+    "tanh",
+    "trunc",
+]
+
+# "<op>_" names for every entry of UNARY_NAMES except the four excluded below, in list order.
+INPLACE_UNARY_NAMES = [
+    n + "_" for n in UNARY_NAMES if n not in {"angle", "positive", "signbit", "isnan"}
+]
+
+# Functions known to be unsupported for now, tracked explicitly (missing code gen or complex
+# semantics). NOTE(review): nothing else in this module appears to read this list.
+UNARY_NAMES_UNSUPPORTED = [
+    "atan2",
+    "arctan2",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "copysign",
+    "float_power",
+    "fmod",
+    "frexp",
+    "gradient",
+    "imag",
+    "ldexp",
+    "lerp",
+    "logical_not",
+    "hypot",
+    "igamma",
+    "igammac",
+    "mvlgamma",
+    "nextafter",
+    "polygamma",
+    "real",
+    "remainder",
+    "true_divide",
+    "xlogy",
+]
+
+
+def _unary_helper(fn, args, kwargs, inplace):  # shared body of the unary wrappers below
+    if len(kwargs) != 0:
+        raise ValueError(
+            "MaskedTensor unary ops require that len(kwargs) == 0. "
+            "If you need support for this, please open an issue on Github."
+        )
+    for a in args[1:]:
+        if torch.is_tensor(a):
+            raise TypeError(
+                "MaskedTensor unary ops do not support additional Tensor arguments"
+            )
+    # Map the arguments twice: once through ``_masked_mask``, once through ``_masked_data``.
+    mask_args, _mask_kwargs = _map_mt_args_kwargs(
+        args, kwargs, lambda x: x._masked_mask
+    )
+    data_args, _data_kwargs = _map_mt_args_kwargs(
+        args, kwargs, lambda x: x._masked_data
+    )
+    # Sparse layouts: run ``fn`` on the stored values only, then rebuild with the same indices.
+    if args[0].layout == torch.sparse_coo:
+        data_args[0] = data_args[0].coalesce()
+        s = data_args[0].size()
+        i = data_args[0].indices()
+        data_args[0] = data_args[0].coalesce().values()  # NOTE(review): second coalesce() looks redundant
+        v = fn(*data_args)
+        result_data = torch.sparse_coo_tensor(i, v, size=s)
+
+    elif args[0].layout == torch.sparse_csr:
+        crow = data_args[0].crow_indices()
+        col = data_args[0].col_indices()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_csr_tensor(crow, col, v)  # no explicit size, unlike the COO branch
+
+    else:
+        result_data = fn(*data_args)
+    # In-place: store the result on args[0] and return it; otherwise wrap a new result.
+    if inplace:
+        args[0]._set_data_mask(result_data, mask_args[0])
+        return args[0]
+    else:
+        return _wrap_result(result_data, mask_args[0])
+
+
+def _torch_unary(fn_name):
+    """Return a wrapper running aten op ``fn_name`` via ``_unary_helper``, out-of-place."""
+    aten_op = getattr(torch.ops.aten, fn_name)
+
+    def unary_fn(*call_args, **call_kwargs):
+        return _unary_helper(aten_op, call_args, call_kwargs, inplace=False)
+    return unary_fn
+
+
+def _torch_inplace_unary(fn_name):
+    """Return a wrapper running aten op ``fn_name`` via ``_unary_helper``, in-place."""
+    aten_op = getattr(torch.ops.aten, fn_name)
+
+    def unary_fn(*call_args, **call_kwargs):
+        return _unary_helper(aten_op, call_args, call_kwargs, inplace=True)
+    return unary_fn
+
+
+NATIVE_UNARY_MAP = {  # torch.ops.aten.<name> -> out-of-place wrapper
+    getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
+}
+NATIVE_INPLACE_UNARY_MAP = {  # torch.ops.aten.<name>_ -> in-place wrapper
+    getattr(torch.ops.aten, name): _torch_inplace_unary(name)
+    for name in INPLACE_UNARY_NAMES
+}
+# Key lists used by the membership checks below.
+NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
+NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())
+
+
+def _is_native_unary(fn):  # True when ``fn`` is a registered out-of-place or in-place unary op
+    return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS
+
+
+def _apply_native_unary(fn, *args, **kwargs):
+    """Call the wrapper registered for ``fn``; ``NotImplemented`` if there is none."""
+    if fn in NATIVE_UNARY_FNS:  # out-of-place ops are consulted first
+        return NATIVE_UNARY_MAP[fn](*args, **kwargs)
+    in_place = fn in NATIVE_INPLACE_UNARY_FNS
+    return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs) if in_place else NotImplemented
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4aee296d392e53e318b0cc690b58a3e1909c0c9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c33661d78e0aa8968b9e4c9d19a85d295e846a3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/event.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/event.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93317e5e5db9df19c8fd8d41c2577311d79005a0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/event.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/profiler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/profiler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88531e99398286357d6e3a7c8036a8176896edfd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mps/__pycache__/profiler.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f9097c9f14f06374f22901477266615cf452112
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d62653e72b6f2e71788717eec82f3400cecf2cc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/memory.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/memory.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82fc4a1a8869dc89d65b17b6d5d8f3a5fce59e24
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/memory.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/mtia_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/mtia_graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e70895c926e0e088513472d03c945d867918396
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/mtia/__pycache__/mtia_graph.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ae19697785c2a0331f2b9ec74e72219236877e1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64fbf6ab6b08d04ff747885ce96c0e3767a57001
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/pool.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/pool.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a13fc81c0dd0fb18546f46e48f9e644a89914ae
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/pool.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/queue.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/queue.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47b253c61bcf91e36596b808332b4788c8eda9bb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/queue.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9bbb0456025424ce340b44a285f0febb9e1d0cb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5353438a7c12622e6e9da834a9fd02677a35314
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2a48381816bfe634c25248f0e591393cbfea542
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f4d1e61792dee0d143ae31fd55aeb79b84262f5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lower_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lower_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4957628f4fc7d9e99adfc2f146b7585266c6f948
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lower_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lowered_aoti_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lowered_aoti_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15da035af9e1ea24b6cd00a6a7d350225c4c6d26
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lowered_aoti_module.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/_lower_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/_lower_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa97bc30b4a047e270dd812b5676de354bf675f3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/_lower_utils.py
@@ -0,0 +1,101 @@
+import types
+
+import torch
+import torch.utils._pytree as pytree
+from torch.export import ExportedProgram
+from torch.export.pt2_archive._package import AOTI_FILES, package_pt2
+from torch.types import FileLike
+
+from ._lowered_aoti_module import LoweredBackendModule
+
+
+def get_new_ep_with_flat_inputs_outputs(ep: ExportedProgram) -> ExportedProgram:
+    """
+    Re-export ``ep`` behind a wrapper module that accepts the flattened
+    ``(args, kwargs)`` leaves positionally and returns a flat tuple of outputs.
+    """
+
+    class FlattenedModule(torch.nn.Module):
+        def __init__(
+            self,
+            original_module: torch.fx.GraphModule,
+            in_spec: pytree.TreeSpec,
+            out_spec: pytree.TreeSpec,
+        ) -> None:
+            super().__init__()
+            self.original_module = original_module
+            self.in_spec = in_spec
+            self.out_spec = out_spec  # stored only; forward() does not read it
+
+        def forward(self, *flat_inputs):  # type: ignore[no-untyped-def]
+            # Rebuild the (args, kwargs) pair the wrapped module expects.
+            call_args, call_kwargs = pytree.tree_unflatten(flat_inputs, self.in_spec)
+            result = self.original_module(*call_args, **call_kwargs)
+            # Hand the outputs back as a flat tuple of leaves.
+            return tuple(pytree.tree_flatten(result)[0])
+
+    # Export the wrapper using the flattened example inputs of ``ep``.
+    wrapper = FlattenedModule(ep.module(), ep.call_spec.in_spec, ep.call_spec.out_spec)
+    example_args, example_kwargs = ep.example_inputs
+    leaves, _ = pytree.tree_flatten((example_args, example_kwargs))
+    return torch.export.export(wrapper, tuple(leaves))
+
+
+def lower_exported_program(
+    exported_program: ExportedProgram, model_name: str, backend_id: str
+) -> tuple[ExportedProgram, AOTI_FILES]:
+    """
+    Lower an exported program to AOTInductor. Returns the delegate ExportedProgram
+    (using the `executorch_call_delegate` HOP) together with the compiled AOTI files.
+    """
+    args, kwargs = exported_program.example_inputs
+    out_spec = exported_program.call_spec.out_spec
+    flat_ep = get_new_ep_with_flat_inputs_outputs(exported_program)
+    flat_inputs, _ = pytree.tree_flatten((args, kwargs))
+    # Compile the flattened program on the flattened example inputs.
+    aoti_files = torch._inductor.aot_compile(
+        flat_ep.module(), tuple(flat_inputs), options={"aot_inductor.package": True}
+    )
+    assert isinstance(aoti_files, list)
+    # The delegate module holds the flattened program, the backend id and the model name.
+    lowered_aoti_module = LoweredBackendModule(
+        flat_ep, backend_id, module_name=model_name
+    )
+    # Replacement forward: flatten inputs, call the delegate HOP, restore the output structure.
+    def patched_forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        flat_inputs, _ = pytree.tree_flatten((args, kwargs))
+        flat_outputs = torch._higher_order_ops.executorch_call_delegate(
+            self, *flat_inputs
+        )
+        if out_spec is not None and flat_outputs is not None:
+            return pytree.tree_unflatten(flat_outputs, out_spec)
+        else:
+            return flat_outputs
+
+    lowered_aoti_module.forward = types.MethodType(patched_forward, lowered_aoti_module)  # type: ignore[method-assign]
+    # Export the patched module with the original (unflattened) example inputs.
+    aoti_delegate_ep = torch.export.export(lowered_aoti_module, args, kwargs)
+
+    return aoti_delegate_ep, aoti_files
+
+
+def package_nativert_with_aoti_delegate(
+    f: FileLike,
+    model_name: str,
+    backend_id: str,
+    original_ep: ExportedProgram,
+    delegate_ep: ExportedProgram,
+    delegate_files: AOTI_FILES,
+) -> None:
+    """
+    Write a pt2 archive to ``f`` holding ``original_ep`` under ``model_name`` and
+    ``delegate_ep`` plus ``delegate_files`` under ``"{model_name}-{backend_id}"``.
+    """
+    # The delegate program and its AOTI files share one archive key.
+    delegate_key = f"{model_name}-{backend_id}"
+    programs = {model_name: original_ep, delegate_key: delegate_ep}
+    package_pt2(
+        f,
+        exported_programs=programs,
+        aoti_files={delegate_key: delegate_files},  # type: ignore[dict-item]
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/_lowered_aoti_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/_lowered_aoti_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08e83211ef330be11788b5ca82a1dcc9a0c9f9d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nativert/backends/_lowered_aoti_module.py
@@ -0,0 +1,31 @@
+import torch
+from torch.export import ExportedProgram
+
+
+class LoweredBackendModule(torch.nn.Module):
+    """Holds an ExportedProgram and a backend id; forward calls ``executorch_call_delegate``."""
+    def __init__(
+        self,
+        original_exported_program: ExportedProgram,
+        backend_id: str,
+        *,
+        module_name: str | None = None,
+    ) -> None:
+        super().__init__()
+        self._original_exported_program = original_exported_program
+        self._backend_id, self._module_name = backend_id, module_name
+
+    @property
+    def backend_id(self) -> str:
+        return self._backend_id
+
+    @property
+    def module_name(self) -> str | None:
+        return self._module_name
+
+    @property
+    def original_module(self) -> ExportedProgram:
+        return self._original_exported_program
+
+    def forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        return torch._higher_order_ops.executorch_call_delegate(self, *args, **kwargs)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60f220a4cf05443575a59a72a2be27d1744de2f5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d901365e54fe9ab6a1221f0e6d4229657e2832a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07ca92f29da8d2080617e23a2bab6911cdb0ad13
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc34e79086ceb648351229feeb58ad5f26d03935
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b294b7bbe8cdd4da7dcb8b1ad2af2a128463843f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py
new file mode 100644
index 0000000000000000000000000000000000000000..b347258b5f463789aa1425f9a8d61de1e306bee7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py
@@ -0,0 +1,116 @@
+from typing import *  # noqa: F403
+
+import torch
+from torch.fx.experimental._constant_symnode import ConstantIntNode
+
+
+__all__ = ["NestedIntNode"]
+
+
+# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp
+def _eq(lhs: Any, rhs: Any) -> bool:
+    """Return whether ``lhs`` and ``rhs`` denote the same nested int.
+
+    Both operands must be ``NestedIntNode`` instances carrying the same
+    ``t_id`` and the same ``coeff``; every other combination compares unequal.
+    """
+    if not isinstance(lhs, NestedIntNode):
+        return False
+    if not isinstance(rhs, NestedIntNode):
+        return False
+    return lhs.t_id == rhs.t_id and lhs.coeff == rhs.coeff
+
+
+def _ge(lhs: Any, rhs: Any) -> bool:
+    """Decide ``lhs >= rhs`` when at least one operand is a ``NestedIntNode``.
+
+    * Two nested ints with the same ``t_id`` are ordered by ``coeff``.
+    * A nested int on the left compares ``>=`` any constant that is ``<= 2``.
+    * A constant ``< 2`` on the left is never ``>=`` a nested int.
+
+    Raises:
+        ValueError: if the relation cannot be decided from the rules above, or
+            if neither operand is a ``NestedIntNode``.
+    """
+    lhs_nested = isinstance(lhs, NestedIntNode)
+    rhs_nested = isinstance(rhs, NestedIntNode)
+    if not (lhs_nested or rhs_nested):
+        raise ValueError("inputs unsupported")
+
+    if lhs_nested and rhs_nested:
+        if lhs.t_id != rhs.t_id:
+            raise ValueError("ge: relation is indeterminate")
+        return lhs.coeff >= rhs.coeff
+
+    if lhs_nested:
+        # nested int on the left, plain node on the right
+        if rhs.is_constant() and rhs.constant_int() <= 2:
+            return True
+    elif lhs.is_constant() and lhs.constant_int() < 2:
+        # plain node on the left, nested int on the right
+        return False
+    raise ValueError("ge: relation is indeterminate")
+
+
+class NestedIntNode:
+    """Node describing a nested int: an id ``t_id`` scaled by a coefficient ``coeff``.
+
+    The comparison methods wrap the outcome of the module-level ``_eq`` / ``_ge``
+    helpers in a constant bool symnode.
+    """
+
+    def __init__(self, t_id: int, coeff: int) -> None:
+        self.t_id, self.coeff = t_id, coeff
+
+    # ---- identity / type queries -------------------------------------------
+
+    def nested_int_coeff(self) -> int:
+        """Return the coefficient of this nested int."""
+        return self.coeff
+
+    def maybe_as_int(self) -> Optional[int]:
+        """Always ``None``: this node does not expose a plain int value."""
+        return None
+
+    def is_int(self) -> bool:
+        return True
+
+    def is_float(self) -> bool:
+        return False
+
+    def is_bool(self) -> bool:
+        return False
+
+    def is_nested_int(self) -> bool:
+        return True
+
+    def clone(self) -> "NestedIntNode":
+        """Return this very node (no copy is made)."""
+        return self
+
+    # ---- printing -----------------------------------------------------------
+
+    def _str(self) -> Any:
+        """Format as ``j<t_id>``, prefixed by ``<coeff>*`` when coeff is not 1."""
+        prefix = "" if self.coeff == 1 else f"{self.coeff}*"
+        return f"{prefix}j{self.t_id}"
+
+    def str(self) -> Any:
+        return self._str()
+
+    def __str__(self) -> Any:
+        return self._str()
+
+    def __repr__(self) -> Any:
+        return self._str()
+
+    def _graph_repr(self) -> Any:
+        return self._str()
+
+    # ---- arithmetic ---------------------------------------------------------
+
+    def mul(self, other: Any) -> "NestedIntNode":
+        """Scale the coefficient by a constant node.
+
+        Raises:
+            ValueError: if ``other`` is not a constant.
+        """
+        if not other.is_constant():
+            raise ValueError(f"unsupported: {type(other)}")
+        factor = other.constant_int()
+        return NestedIntNode(self.t_id, self.coeff * factor)
+
+    # ---- comparisons (each returns a constant bool symnode) -----------------
+
+    def eq(self, other: Any) -> Any:
+        outcome = _eq(self, other)
+        return torch._C._get_constant_bool_symnode(outcome)
+
+    def ne(self, other: Any) -> Any:
+        outcome = not _eq(self, other)
+        return torch._C._get_constant_bool_symnode(outcome)
+
+    def gt(self, other: Any) -> Any:
+        # self > other  <=>  not (other >= self)
+        outcome = not _ge(other, self)
+        return torch._C._get_constant_bool_symnode(outcome)
+
+    def lt(self, other: Any) -> Any:
+        # self < other  <=>  not (self >= other)
+        outcome = not _ge(self, other)
+        return torch._C._get_constant_bool_symnode(outcome)
+
+    def le(self, other: Any) -> Any:
+        # self <= other  <=>  other >= self
+        outcome = _ge(other, self)
+        return torch._C._get_constant_bool_symnode(outcome)
+
+    def ge(self, other: Any) -> Any:
+        outcome = _ge(self, other)
+        return torch._C._get_constant_bool_symnode(outcome)
+
+    # ---- misc ---------------------------------------------------------------
+
+    def is_symbolic(self) -> bool:
+        return False
+
+    def nested_int(self) -> int:
+        """Return the id (``t_id``) of this nested int."""
+        return self.t_id
+
+    def is_constant(self) -> bool:
+        return False
+
+    def wrap_int(self, num: int) -> ConstantIntNode:
+        """Wrap a plain Python ``int`` in a ``ConstantIntNode``."""
+        assert type(num) is int
+        return ConstantIntNode(num)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4e3fecf4e6cc3947d07757896a5eb1e9d7935b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py
@@ -0,0 +1,676 @@
+# mypy: allow-untyped-defs
+from typing import *  # noqa: F403
+
+import torch
+from torch._C import DispatchKey, DispatchKeySet
+from torch._prims_common import is_expandable_to
+from torch.nested._internal.nested_int import NestedIntNode
+from torch.utils.weak import WeakTensorKeyDictionary
+
+
+# Id given to the next NestedIntNode that get_tensor_symint() creates; it is
+# incremented each time a tensor is newly registered.
+_tensor_id_counter = 0
+# Registry (a WeakTensorKeyDictionary) from a tensor to the SymInt that
+# get_tensor_symint() created for it.
+_tensor_symint_registry = WeakTensorKeyDictionary()
+
+
+def get_tensor_symint(tensor, *, coeff=1):
+    """Return the nested-int SymInt associated with ``tensor``.
+
+    After unwrapping a possible functional-tensor wrapper, a FakeTensor supplies
+    its own nested int through ``get_nested_int(coeff=coeff)``. Any other tensor
+    is looked up in the module-level registry: on a hit the registered SymInt is
+    returned unchanged (``coeff`` is not consulted); on a miss a new
+    ``NestedIntNode`` is created from the next free id and ``coeff``, registered
+    and returned.
+    """
+    global _tensor_id_counter
+
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
+
+    # NB: Only FakeTensor is associated with a memo
+    unwrapped = mb_unwrap_functional_tensor(tensor)
+    if isinstance(unwrapped, FakeTensor):
+        return unwrapped.get_nested_int(coeff=coeff)
+
+    registered = _tensor_symint_registry.get(unwrapped)
+    if registered is not None:
+        return registered
+
+    fresh = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff))
+    _tensor_id_counter += 1
+    _tensor_symint_registry[unwrapped] = fresh
+    return fresh
+
+
+# SDPA metadata; max / min seqlens are needed for e.g. flash
+def _get_sdpa_extreme_seqlen(func, tensor):
+    """Apply the reduction ``func`` to ``tensor`` and return the result as a Python int."""
+    extreme = func(tensor)
+    return int(extreme.item())
+
+
+def _store_val_in_tensor(val) -> torch.Tensor:
+    """Encode ``val`` as the first dimension of an element-free ``(val, 0)`` tensor.
+
+    The counterpart ``_load_val_from_tensor`` reads the value back from the shape.
+    """
+    # hack to get dynamic shapes support: store in a (val, 0) shaped tensor
+    return torch.zeros(val, 0)
+
+
+def _load_val_from_tensor(t: torch.Tensor):
+    """Return the value stored by ``_store_val_in_tensor``, i.e. ``t``'s first dimension size."""
+    return t.shape[0]
+
+
+# serialization function must be defined at top level
+def _rebuild_njt(constructor_kwargs):
+    """Reconstruct a NestedTensor by passing ``constructor_kwargs`` to its constructor."""
+    rebuilt = NestedTensor(**constructor_kwargs)
+    return rebuilt
+
+
+class NestedTensor(torch.Tensor):
+    """Jagged-layout nested tensor implemented as a wrapper tensor subclass.
+
+    An instance wraps a ``_values`` tensor together with an ``_offsets`` tensor
+    and an optional ``_lengths`` tensor. The outer size / strides are computed in
+    ``__new__`` and contain a nested int at the ragged dimension.
+    """
+
+    # Inner tensors wrapped by this subclass.
+    _values: torch.Tensor  # type: ignore[assignment]
+    _offsets: torch.Tensor
+    _lengths: Optional[torch.Tensor]
+    # NOTE [ Nested ints for ragged sizes and strides ]
+    #
+    # Jagged layout tensors are tensors that represent a n-dim tensor with a
+    # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
+    # a jagged tensor with outer shape [B, x, D] is represented internally by a
+    # tensor with shape [sum(x), D] where we introduce what we call a nested int
+    # denoted as "x" here (but sometimes denoted with "*") to represent the
+    # ragged dimension, and sum(x) represents the dim of the inner tensor or
+    # equivalently the sum of all the sizes of the constituent tensors' varying
+    # lengths.
+    #
+    # We also use nested ints to represent the strides of this tensor.
+    # For example, a jagged tensor with shape [B, x, D] can be strided in two
+    # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
+    _size: tuple[int, ...]
+    _strides: tuple[int, ...]
+    # Indicates that the nth dimension is ragged
+    _ragged_idx: int
+    # Lazily computed properties; this class stores "min_seqlen" / "max_seqlen"
+    # tensors under those keys.
+    _metadata_cache: Dict[str, Any]
+
+    @staticmethod
+    def __new__(
+        cls,
+        values,
+        offsets,
+        *,
+        lengths=None,
+        **kwargs,
+    ):
+        """Create the wrapper tensor and compute its outer size and strides.
+
+        Recognized ``kwargs``: ``_ragged_idx`` (default 1) and ``requires_grad``
+        (default False). ``offsets`` must be a 1-D tensor on the same device as
+        ``values``, and ``values`` must not itself be a NestedTensor.
+        """
+        ks = DispatchKeySet(DispatchKey.NestedTensor)
+        ks = ks.add(DispatchKey.AutogradNestedTensor)
+
+        # Only support jagged for now.
+        assert offsets is not None
+        assert offsets.ndim == 1
+        assert not isinstance(values, NestedTensor)
+        assert values.device == offsets.device
+
+        # Query cache for the symint associated with offsets or lengths
+        # (create a new one if needed).
+        ragged_source = offsets if lengths is None else lengths
+        ragged_size = get_tensor_symint(ragged_source, coeff=1)
+        _ragged_idx = kwargs.get("_ragged_idx", 1)
+        # offsets holds one more entry than there are batch items
+        B = offsets.shape[0] - 1
+        if lengths is not None:
+            assert B == lengths.shape[0]
+
+        # subtract 1 to convert to values dim space
+        r = _ragged_idx - 1
+        # outer size: batch dim, then the values dims with the ragged one replaced
+        # by the nested int
+        _size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
+        stride = values.stride()
+        # outer strides: a nested-int batch stride followed by the values strides
+        _strides = (ragged_size * stride[r], *stride)
+
+        # NB: ``r`` is rebound here from the inner ragged dim to the new tensor
+        r = torch.Tensor._make_wrapper_subclass(
+            cls,
+            _size,
+            _strides,
+            0,
+            torch.contiguous_format,
+            values.dtype,
+            torch.jagged,
+            values.device,
+            False,
+            kwargs.get("requires_grad", False),
+            "sizes",
+            False,
+            True,  # dispatch_layout
+            ks,
+            # don't try to calculate storage based on non-zero size
+            storage_size=values.untyped_storage().size(),
+        )
+        r._ragged_idx = _ragged_idx
+        r._size = _size
+        r._strides = _strides
+
+        return r
+
+    def __init__(self, values, offsets, *, lengths=None, **kwargs) -> None:
+        """Attach the inner tensors and the metadata cache, and mark dynamic dims.
+
+        ``kwargs`` may carry ``_metadata_cache``; a missing or empty value results
+        in a fresh empty dict.
+        """
+        super().__init__()
+
+        self._values = values
+        self._offsets = offsets
+        self._lengths = lengths
+
+        # holds properties that are computed lazily
+        self._metadata_cache = kwargs.get("_metadata_cache") or {}
+
+        # collapsed ragged dim must always be dynamic
+        torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)
+
+        # min / max sequence length should be dynamic if present
+        max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
+        if max_seqlen_tensor is not None:
+            torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
+        min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
+        if min_seqlen_tensor is not None:
+            torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)
+
+    def values(self):
+        """Return the values of this nested tensor via ``torch._nested_get_values``."""
+        # dispatch to get proper view relationship
+        return torch._nested_get_values(self)  # type: ignore[attr-defined]
+
+    def offsets(self):
+        """Return the stored offsets tensor."""
+        return self._offsets
+
+    def lengths(self):
+        """Return the stored lengths tensor, or None when none was given."""
+        return self._lengths
+
+    # Private accessor functions for min / max sequence length. They're
+    # purposefully not @properties because those don't work with PT2 (yet).
+    # These compute / cache if not present.
+    # TODO: Revisit this when @properties are better supported by PT2. I think the ideal
+    # state would be to have public @properties for min / max sequence length that compile
+    # (including setters).
+    def _get_max_seqlen(self):
+        """Return the max sequence length, computing and caching it if absent.
+
+        The value is computed as the max of ``_lengths`` when present, otherwise of
+        the differences between consecutive offsets.
+        """
+        max_seqlen_tensor = self._max_seqlen_tensor
+        if max_seqlen_tensor is None:
+            # compute & cache
+            max_val = _get_sdpa_extreme_seqlen(
+                torch.max,
+                self._offsets.diff() if self._lengths is None else self._lengths,
+            )
+            max_seqlen_tensor = _store_val_in_tensor(max_val)
+            self._metadata_cache["max_seqlen"] = max_seqlen_tensor
+        return _load_val_from_tensor(max_seqlen_tensor)
+
+    def _get_min_seqlen(self):
+        """Return the min sequence length, computing and caching it if absent.
+
+        The value is computed as the min of ``_lengths`` when present, otherwise of
+        the differences between consecutive offsets.
+        """
+        min_seqlen_tensor = self._min_seqlen_tensor
+        if min_seqlen_tensor is None:
+            # compute & cache
+            min_val = _get_sdpa_extreme_seqlen(
+                torch.min,
+                self._offsets.diff() if self._lengths is None else self._lengths,
+            )
+            min_seqlen_tensor = _store_val_in_tensor(min_val)
+            self._metadata_cache["min_seqlen"] = min_seqlen_tensor
+        return _load_val_from_tensor(min_seqlen_tensor)
+
+    # Private accessors used for treating min / max seqlen as inner tensors for
+    # flatten / unflatten. These must be properties to work with the traceable wrapper
+    # subclass logic. These do not compute / cache if not present.
+    @property
+    def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
+        """Cached max-seqlen tensor, or None if nothing is cached."""
+        return self._metadata_cache.get("max_seqlen", None)
+
+    @_max_seqlen_tensor.setter
+    def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+        self._metadata_cache["max_seqlen"] = val
+
+    @property
+    def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
+        """Cached min-seqlen tensor, or None if nothing is cached."""
+        return self._metadata_cache.get("min_seqlen", None)
+
+    @_min_seqlen_tensor.setter
+    def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+        self._metadata_cache["min_seqlen"] = val
+
+    # These are old private @property accessors that are kept around for internal BC
+    # reasons. TODO: Remove these!
+    @property
+    def _max_seqlen(self):
+        return self._get_max_seqlen()
+
+    @property
+    def _min_seqlen(self):
+        return self._get_min_seqlen()
+
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    def _is_contiguous_or_false(self):
+        """Return False when lengths are present; otherwise defer to the values tensor.
+
+        Without lengths, the result is that of
+        ``is_contiguous_for_memory_format_or_false`` applied to ``_values`` with
+        ``torch.contiguous_format``.
+        """
+        if self.lengths() is not None:
+            return False
+        from torch._prims_common import is_contiguous_for_memory_format_or_false
+
+        return is_contiguous_for_memory_format_or_false(
+            self._values, memory_format=torch.contiguous_format
+        )
+
+    def __repr__(self) -> str:  # type: ignore[override]
+        """Format size, offsets, grad info and contiguity into a one-line summary."""
+        # We should implement this in torch/_tensor_str.py instead
+        grad_fn_str = (
+            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
+        )
+
+        # a grad_fn, when present, replaces the requires_grad text
+        if self.grad_fn:
+            grad_fn_str = f", grad_fn={self.grad_fn}"
+
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
+
+    # TODO: Remove this in favor of the default tensor subclass serialization logic.
+    # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
+    def __reduce_ex__(self, proto):
+        """Serialize as a call to ``_rebuild_njt`` plus the remaining object state.
+
+        ``_size`` and ``_strides`` are dropped from the state; the constructor
+        recomputes them when the tensor is rebuilt.
+        """
+        # NB: the state is captured before the cached data is cleared below
+        state = torch._utils._get_obj_state(self)
+
+        # Cached PyCapsules for sizes / strides are not serializable.
+        # See Note [Tensor Subclass custom size/stride caching strategy]
+        self._clear_non_serializable_cached_data()
+        # SymNodes are not serializable
+        assert "_size" in state and "_strides" in state
+        # work on a copy so the deletions below do not touch the captured state
+        state = dict(state)
+        del state["_size"]
+        del state["_strides"]
+
+        func = _rebuild_njt
+        constructor_kwargs = {
+            "values": self._values,
+            "offsets": self._offsets,
+            "lengths": self._lengths,
+            "_ragged_idx": self._ragged_idx,
+            "_metadata_cache": self._metadata_cache,
+            "requires_grad": self.requires_grad,
+        }
+        args = (constructor_kwargs,)
+        return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))
+
+    def __tensor_flatten__(self):
+        """Return the names of the inner tensor attributes and a context dict.
+
+        ``_values`` and ``_offsets`` are always listed; ``_lengths`` and the cached
+        min / max seqlen tensors are listed only when present. The context carries
+        ``requires_grad`` and ``ragged_idx``.
+        """
+        ctx = {
+            "requires_grad": self.requires_grad,
+            "ragged_idx": self._ragged_idx,
+        }
+        inner_tensors = ["_values", "_offsets"]
+        if self._lengths is not None:
+            inner_tensors.append("_lengths")
+        if self._min_seqlen_tensor is not None:
+            inner_tensors.append("_min_seqlen_tensor")
+        if self._max_seqlen_tensor is not None:
+            inner_tensors.append("_max_seqlen_tensor")
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
+        """Rebuild a NestedTensor from the output of ``__tensor_flatten__``.
+
+        ``inner_tensors`` maps attribute names to tensors and ``meta`` is the
+        context dict. ``outer_stride`` is not used here.
+        """
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
+        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
+        values = inner_tensors["_values"]
+        offsets = inner_tensors["_offsets"]
+        lengths = inner_tensors.get("_lengths", None)
+        min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
+        max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)
+
+        metadata_cache = {}
+        if min_seqlen_tensor is not None:
+            metadata_cache["min_seqlen"] = min_seqlen_tensor
+        if max_seqlen_tensor is not None:
+            metadata_cache["max_seqlen"] = max_seqlen_tensor
+        ragged_idx = meta["ragged_idx"]
+
+        # Alternatively, we could make it the caller's responsibility to
+        # cache it. But this heuristic seems simple enough.
+        ragged_source = offsets if lengths is None else lengths
+        if isinstance(ragged_source, FakeTensor):
+            # seed the fake tensor's memo with the ragged size from outer_size
+            ragged_size = outer_size[ragged_idx]
+            ragged_source.nested_int_memo = ragged_size
+
+        return NestedTensor(
+            values,
+            offsets=offsets,
+            lengths=lengths,
+            requires_grad=meta["requires_grad"],
+            _ragged_idx=ragged_idx,
+            _metadata_cache=metadata_cache,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):  # type: ignore[override]
+        """Route ``func`` to its jagged implementation, else to a composite kernel.
+
+        Raises:
+            NotImplementedError: if ``lookup_jagged`` returns nothing and ``func``
+                has no kernel for either composite dispatch key tried below.
+        """
+        # If you're wondering why there's a nested tensor with one of its
+        # size = -1, see note: [NJT outer_size in AOTDispatcher]
+        kwargs = {} if kwargs is None else kwargs
+
+        # Lazy import to avoid circular dependency
+        from .ops import lookup_jagged
+
+        fn = lookup_jagged(func, *args, **kwargs)
+        if fn is not None:
+            return fn(*args, **kwargs)
+
+        # Poor man's redispatch for composite ops. This becomes relevant under inference
+        # mode, where disabling autograd key dispatch prevents decomposition.
+        all_dks = (
+            # We want to handle both the cases where NestedTensor overrides the
+            # composite implicit autograd kernel, and the case where it doesn't.
+            # Prioritize calling into NestedTensor's kernel if it exists.
+            torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
+            torch._C.DispatchKey.CompositeImplicitAutograd,
+        )
+        for dk in all_dks:
+            if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
+                with torch.overrides.enable_reentrant_dispatch():
+                    return func._op_dk(dk, *args, **kwargs)
+
+        raise NotImplementedError(func)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        """Try the jagged torch-function override first, then the regular ``func``.
+
+        If ``jagged_torch_function`` raises NotImplementedError, ``func`` is called
+        with torch-function subclass handling disabled.
+        """
+        if kwargs is None:
+            kwargs = {}
+
+        from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
+
+        from .ops import jagged_torch_function
+
+        # This should be removed after
+        # https://github.com/pytorch/pytorch/pull/125941/ lands
+        with maybe_enable_thunkify():
+            try:
+                return jagged_torch_function(func, *args, **kwargs)
+            except NotImplementedError:
+                pass
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+
+
+# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
+# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
+# internal BC period has passed.
+
+
+# Not actually a view!
+class ViewBufferFromNested(torch.autograd.Function):
+    """Autograd function mapping a NestedTensor to its ``_values`` buffer.
+
+    The backward pass wraps the incoming gradient in a NestedTensor that reuses
+    the forward input's offsets, metadata cache and ragged index.
+    """
+
+    @staticmethod
+    def forward(ctx, x: NestedTensor):  # type: ignore[override]
+        # Remember everything backward() needs to rebuild a NestedTensor.
+        ctx.ragged_idx = x._ragged_idx
+        ctx.metadata_cache = x._metadata_cache
+        ctx.save_for_backward(x.offsets())
+        return x._values
+
+    @staticmethod
+    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
+        [saved_offsets] = ctx.saved_tensors
+        rebuild_kwargs = {
+            "offsets": saved_offsets,
+            "_metadata_cache": ctx.metadata_cache,
+            "_ragged_idx": ctx.ragged_idx,
+        }
+        return NestedTensor(gO, **rebuild_kwargs)
+
+
+# Not actually a view!
+class ViewNestedFromBuffer(torch.autograd.Function):
+    """Autograd function wrapping a values buffer and offsets into a NestedTensor.
+
+    The backward pass returns the ``_values`` of the incoming gradient for
+    ``values`` and no gradient for the other two inputs.
+    """
+
+    @staticmethod
+    def forward(  # pyrefly: ignore  # bad-override
+        ctx,
+        values: torch.Tensor,
+        offsets: torch.Tensor,
+        metadata_cache: Optional[Dict[str, Any]] = None,
+    ):  # type: ignore[override]
+        # maintain BC with usages of this where the seqlens are stuffed
+        # directly into the metadata cache as non-Tensors / ints; such entries
+        # are converted to tensors in place (the caller's dict is updated)
+        if metadata_cache is not None:
+            for key in ("min_seqlen", "max_seqlen"):
+                cached = metadata_cache.get(key, None)
+                if cached is None or isinstance(cached, torch.Tensor):
+                    continue
+                metadata_cache[key] = _store_val_in_tensor(cached)
+
+        detached_values = values.detach()
+        return NestedTensor(
+            detached_values,
+            offsets=offsets,
+            _metadata_cache=metadata_cache,
+        )
+
+    @staticmethod
+    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
+        # one entry per forward() input: values, offsets, metadata_cache
+        grad_values = gO._values
+        return grad_values, None, None
+
+
+def buffer_from_jagged(jagged):
+    """Return the values buffer of ``jagged`` by applying ``ViewBufferFromNested``."""
+    buffer = ViewBufferFromNested.apply(jagged)
+    return buffer
+
+
+# Need to make it obvious that users should be passing in offsets
+def jagged_from_list(
+    tensors: List[torch.Tensor],
+    offsets: Optional[torch.Tensor],
+    dtype=None,
+    device=None,
+) -> tuple[NestedTensor, torch.Tensor]:
+    """Build a jagged-layout NestedTensor out of a list of component tensors.
+
+    Args:
+        tensors: non-empty list of tensors sharing dtype, device and dim
+            (dim >= 1); at most one dimension may differ in size across the list.
+        offsets: offsets to use; when None they are computed from the component
+            sizes along the ragged dimension.
+        dtype: optional dtype the concatenated values are converted to.
+        device: optional device the concatenated values are moved to.
+
+    Returns:
+        The pair ``(nested_tensor, offsets)``.
+
+    Raises:
+        RuntimeError: on an empty list, mismatched dtype / device / dim, zero-dim
+            components, or more than one ragged dimension.
+    """
+    if len(tensors) == 0:
+        raise RuntimeError("Cannot construct a nested tensor from an empty tensor list")
+    if len({t.dtype for t in tensors}) != 1:
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must have the same dtype"
+        )
+    if len({t.device for t in tensors}) != 1:
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must be on the same device"
+        )
+    if len({t.dim() for t in tensors}) != 1:
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must have the same dim"
+        )
+    component_dim = tensors[0].dim()
+    if component_dim == 0:
+        raise RuntimeError(
+            "Cannot construct a nested tensor from a list of zero-dim tensors"
+        )
+
+    # Check that the NT is representable by the jagged layout, which
+    # allows for a single ragged dimension after the batch dim.
+    # e.g. (B, *, D_0, ..., D_N), (B, D_0, *, ..., D_N), etc.
+    shapes = [t.shape for t in tensors]
+    reference = shapes[0]
+    ragged_dims = [
+        d
+        for d in range(component_dim)
+        if any(shape[d] != reference[d] for shape in shapes)
+    ]
+    if len(ragged_dims) > 1:
+        raise RuntimeError(
+            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
+            "Note that the jagged layout only allows for a single ragged dimension. "
+            "For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim."
+        )
+    # add 1 to convert to outer NJT dim space; a rectangular NJT defaults the
+    # ragged dim to the one next to the batch dim
+    ragged_idx = ragged_dims[0] + 1 if ragged_dims else 1
+    inner_ragged = ragged_idx - 1
+
+    # Set properties appropriately.
+    to_kwargs = {
+        name: value
+        for name, value in (("device", device), ("dtype", dtype))
+        if value is not None
+    }
+    values = torch.cat(tensors, dim=inner_ragged).to(**to_kwargs)
+
+    # per-component size along the ragged dimension
+    seqlens = [shape[inner_ragged] for shape in shapes]
+
+    # Calculate jagged offsets if not provided.
+    if offsets is None:
+        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+        # TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
+        # an extra leaf tensor during the forward, potentially resolving compatibility issues.
+        leading_zero = torch.zeros(1, dtype=torch.int64, device=values.device)
+        running_total = torch.tensor(seqlens, device=values.device).cumsum(dim=0)
+        offsets = torch.cat([leading_zero, running_total])
+
+    # compute min / max seqlen now since it's easy
+    ret_nt = nested_view_from_values_offsets(
+        values,
+        offsets,
+        min_seqlen=min(seqlens),
+        max_seqlen=max(seqlens),
+        ragged_idx=ragged_idx,
+    )
+    return (ret_nt, offsets)  # type: ignore[return-value]
+
+
+def jagged_from_tensor_and_lengths(
+    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
+) -> tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths
+
+    Args:
+        tensor: dense tensor with at least 2 dims; dims 0 and 1 are flattened
+            together to form the values buffer.
+        starts: tensor expandable to ``(tensor.shape[0],)``.
+        lengths: tensor expandable to ``(tensor.shape[0],)``.
+
+    Returns:
+        ``(nested_tensor, offsets, lengths_or_None)``; the last entry is None when
+        the contiguity checks below pass, otherwise the expanded lengths.
+
+    Raises:
+        RuntimeError: if ``starts`` or ``lengths`` cannot be expanded to
+            ``(tensor.shape[0],)``.
+    """
+    batch_size = tensor.shape[0]
+    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
+        lengths.shape, (batch_size,)
+    ):
+        start_list = starts.expand(batch_size)
+        length_list = lengths.expand(batch_size)
+    else:
+        raise RuntimeError(
+            "When constructing a jagged nested tensor using narrow(), "
+            "your start and length must be Tensors that broadcast to input.shape[0]"
+        )
+
+    # Calculate jagged offsets
+    assert len(tensor.shape) >= 2, (
+        "tensor must at least be 2D for the nested narrow op to work"
+    )
+    max_seq_len = tensor.shape[1]
+    # position of each batch item's row 0 inside the flattened buffer
+    offset_lengths = max_seq_len * torch.arange(
+        0, batch_size, dtype=torch.int64, device=tensor.device
+    )
+    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+    # One start per batch item, plus a final entry marking the end of the last sequence.
+    offsets = torch.cat(
+        [
+            start_list + offset_lengths,
+            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
+        ]
+    )
+
+    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
+    if len(tensor.shape) > 2:
+        values = tensor.view(-1, *tensor.shape[2:])
+    else:
+        values = tensor.view(-1)
+
+    # Check if offsets and lengths make it possibly contiguous and return a regular NT
+    is_contiguous = True
+    orig_dim = tensor.shape[1]
+    # each check below can only clear the flag, never set it again
+    if torch.any(length_list[1:-1].ne(orig_dim)):
+        is_contiguous = False
+    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
+        is_contiguous = False
+    if offsets[0] + length_list[0] != orig_dim:
+        is_contiguous = False
+
+    # NB: computed from the ``lengths`` argument, not the expanded ``length_list``
+    actual_max_seqlen = int(torch.max(lengths).item())
+    min_seqlen = int(torch.min(lengths).item())
+
+    if is_contiguous:
+        # slice the buffer to the covered range and rebase the offsets to start at 0
+        ret_nt = nested_view_from_values_offsets(
+            values[offsets[0] : offsets[-1]],
+            offsets - offsets[0],
+            min_seqlen=min_seqlen,
+            max_seqlen=actual_max_seqlen,
+        )
+    else:
+        ret_nt = nested_view_from_values_offsets_lengths(
+            values,
+            offsets,
+            length_list,
+            min_seqlen=min_seqlen,
+            max_seqlen=actual_max_seqlen,
+        )
+
+    return (ret_nt, offsets, None if is_contiguous else length_list)
+
+
+# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
+# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
+# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
+# This arg is otherwise unused.
+# The instance is created on the first call to _nt_view_dummy() and reused afterwards.
+_dummy_instance: Optional[torch.Tensor] = None
+
+
+def _nt_view_dummy() -> torch.Tensor:
+    """Return the shared meta-device NestedTensor dummy, creating it on first use."""
+    global _dummy_instance
+    if _dummy_instance is not None:
+        return _dummy_instance
+
+    meta_values = torch.zeros(3, 3, device="meta")
+    meta_offsets = torch.zeros(3, device="meta", dtype=torch.int64)
+    _dummy_instance = NestedTensor(values=meta_values, offsets=meta_offsets).detach()
+    return _dummy_instance
+
+
+def nested_view_from_values_offsets(
+    values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
+):
+    """Create a jagged nested view over ``values`` described by ``offsets`` only.
+
+    ``min_seqlen`` / ``max_seqlen``, when given, are stored in tensors via
+    ``_store_val_in_tensor`` before being forwarded to
+    ``torch._nested_view_from_jagged``; no lengths are passed.
+    """
+    min_seqlen_tensor = (
+        None if min_seqlen is None else _store_val_in_tensor(min_seqlen)
+    )
+    max_seqlen_tensor = (
+        None if max_seqlen is None else _store_val_in_tensor(max_seqlen)
+    )
+    no_lengths = None
+
+    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
+        values,
+        offsets,
+        _nt_view_dummy(),
+        no_lengths,
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+    )  # type: ignore[return-value]
+
+
+def nested_view_from_values_offsets_lengths(
+    values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
+):
+    """Create a jagged nested view over ``values`` described by offsets and lengths.
+
+    ``min_seqlen`` / ``max_seqlen``, when given, are stored in tensors via
+    ``_store_val_in_tensor`` before being forwarded to
+    ``torch._nested_view_from_jagged``.
+    """
+    min_seqlen_tensor = (
+        None if min_seqlen is None else _store_val_in_tensor(min_seqlen)
+    )
+    max_seqlen_tensor = (
+        None if max_seqlen is None else _store_val_in_tensor(max_seqlen)
+    )
+
+    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
+        values,
+        offsets,
+        _nt_view_dummy(),
+        lengths,
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+    )  # type: ignore[return-value]
+
+
+def nested_from_padded(
+    padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None
+):
+    """Forward a padded tensor and offsets to ``torch._nested_from_padded_tensor``.
+
+    ``min_seqlen`` / ``max_seqlen``, when given, are first stored in tensors via
+    ``_store_val_in_tensor``; ``sum_S`` is passed through unchanged.
+    """
+    min_seqlen_tensor = (
+        None if min_seqlen is None else _store_val_in_tensor(min_seqlen)
+    )
+    max_seqlen_tensor = (
+        None if max_seqlen is None else _store_val_in_tensor(max_seqlen)
+    )
+
+    return torch._nested_from_padded_tensor(
+        padded,
+        offsets,
+        _nt_view_dummy(),
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+        sum_S,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..200ccd653f6c3b4e9eeca8c28468c362cae93e86
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/ops.py
@@ -0,0 +1,2748 @@
+# mypy: allow-untyped-defs
+import functools
+import math
+import operator
+from typing import *  # noqa: F403
+
+import torch
+import torch.nn.functional as F
+from torch.fx.operator_schemas import normalize_function
+from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
+
+from .nested_tensor import NestedTensor
+
+
+# Empty export list: a star-import of this module picks up no names.
+__all__: list[Any] = []
+
+# Dispatch table for jagged-layout ops. It is filled by the
+# ``register_jagged_func`` decorator and read by ``lookup_jagged``.
+JAGGED_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def _get_padding_value(dtype, padding_type):
+    if dtype.is_floating_point:
+        return (
+            torch.finfo(dtype).max if padding_type == "max" else torch.finfo(dtype).min
+        )
+    elif dtype == torch.int64:
+        # Largest int64 value exactly representable in float64 (IEEE 754 double precision).
+        # Avoids overflow when padding_value is passed as double to _jagged_to_padded_dense_forward.
+        int64_safe_max = (1 << 53) - 1
+        int64_safe_min = -int64_safe_max
+        return int64_safe_max if padding_type == "max" else int64_safe_min
+    else:
+        return (
+            torch.iinfo(dtype).max if padding_type == "max" else torch.iinfo(dtype).min
+        )
+
+
+def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
+    from torch._prims_common import canonicalize_dims
+
+    if isinstance(dim, (tuple, list)):
+        output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
+        # ensure no duplicates, which can result from both batch and ragged mapping to 0
+        return type(output)(dict.fromkeys(output))
+
+    if canonicalize:
+        dim = canonicalize_dims(ndim, dim)
+
+    assert dim >= 0 and dim < ndim  # pyrefly: ignore [unsupported-operation]
+
+    # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
+    # For other dims, subtract 1 to convert to inner space.
+    return (
+        # pyrefly: ignore [unsupported-operation]
+        ragged_dim - 1 if dim == 0 else dim - 1
+    )
+
+
+def _wrap_jagged_dim(
+    ndim,
+    dim,
+    ragged_dim,
+    op_name,
+    convert_to_inner_dim=True,
+    allow_ragged_dim=False,
+    allow_batch_dim=False,
+):
+    from torch._prims_common import canonicalize_dims
+
+    wrapped = canonicalize_dims(ndim, dim)
+    if wrapped == ragged_dim and not allow_ragged_dim:
+        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
+    elif wrapped == 0 and not allow_batch_dim:
+        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
+    ret = (
+        _outer_to_inner_dim(ndim, wrapped, ragged_dim)
+        if convert_to_inner_dim
+        else wrapped
+    )
+    if allow_batch_dim:
+        # Need to disambiguate whether we're operating on the batch dim or not.
+        # Operating on dim=1 -> dim=0 after the inner dim conversion.
+        operating_on_batch = wrapped == 0
+        return (ret, operating_on_batch)
+    return ret
+
+
+def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
+    """
+    For NestedTensor operators,
+    wraps dimensions to non-negative values,
+    and returns metadata related to reduction dimension(s).
+    """
+    from torch._prims_common import canonicalize_dims
+
+    assert isinstance(dims, (tuple, list)), (
+        f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
+    )
+
+    wrapped_dims = [
+        canonicalize_dims(ndim, d) for d in dims
+    ]  # convert all indices to non-negative values
+
+    operate_on_batch = 0 in wrapped_dims
+    operate_on_ragged = ragged_idx in wrapped_dims
+    operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
+
+    # ensure no duplicates, which can result from both batch and ragged mapping to 0
+    outer_to_inner_dim = tuple(
+        dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
+    )
+
+    return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
+
+
+def check_schema(schema_str: str, func, *args, **kwargs) -> None:
+    named_arg_types = schema_str.split(", ")
+    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
+    min_args = len(named_arg_types) - num_optional_args
+
+    # special case: ellipses allows for any number of unchecked args at the end
+    if named_arg_types[-1] == "...":
+        named_arg_types = named_arg_types[:-1]
+    else:
+        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
+            raise ValueError(
+                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
+                f"arguments and at most {len(named_arg_types)} arguments, but got: "
+                f"{len(args)} arguments"
+            )
+
+    arg_type_check_fns = {
+        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
+        "jt": lambda x: isinstance(x, NestedTensor)
+        and x._lengths is None
+        and x._ragged_idx == 1,  # ops with "jt" require contiguous JT only
+        "jt_all": lambda x: isinstance(
+            x, NestedTensor
+        ),  # ops with "jt_all" can accept all kinds of JT
+        "any": lambda x: True,
+    }
+    for i, named_arg_type in enumerate(named_arg_types):
+        name, arg_type = named_arg_type.split(": ")
+        is_optional = arg_type.endswith("?")
+        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
+        if normalized_arg_type not in arg_type_check_fns:
+            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
+
+        if i >= len(args):
+            if not is_optional:
+                raise ValueError(
+                    f"NestedTensor {func.__name__}({schema_str}) "
+                    f"missing required argument: {name}"
+                )
+            continue
+
+        _check_fn = arg_type_check_fns[normalized_arg_type]
+
+        def check_fn(x, is_optional=is_optional):
+            if is_optional:
+                return x is None or _check_fn(x)
+            else:
+                return _check_fn(x)
+
+        if not check_fn(args[i]):
+            type_to_desc = {
+                "t": "tensor",
+                "t?": "optional tensor",
+                "jt": "contiguous jagged layout NestedTensor",
+                "jt_all": "jagged layout NestedTensor",
+                "any": "",
+            }
+
+            raise ValueError(
+                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
+                f"{type_to_desc[arg_type]}"
+            )
+
+
+def check_ragged_dim_same(
+    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
+) -> None:
+    # Calling into .shape here
+    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
+        raise RuntimeError(
+            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
+            "same exact offsets tensor."
+        )
+
+
+# returns True if the raggedness-relevant portions of the NT shape
+# match those of the specified size
+def raggedness_matches(nt, size):
+    end = nt._ragged_idx + 1
+    nt_ragged = nt._size[:end]
+    size_ragged = size[:end]
+    return len(nt_ragged) == len(size_ragged) and (
+        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
+    )
+
+
+def squeeze_leading_ones(t):
+    # Note: [ Squeezing leading ones ]
+    #
+    # Squeeze leading ones from t.
+    #
+    # We want:
+    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
+    #
+    # 1) Squeeze extra ones and grab values from NT
+    #   (1, 1, ?, ?) -> (?, ?)   and   (sum(*), ?, ?) -> (B, j0, ?, ?)
+    # 2) Do dense broadcasting:
+    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
+    # 3) Construct nested tensor
+    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
+    #
+    # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
+    # at step (4) and we would need to update this function to record how
+    # many ones we unsqueezed.
+    while t.dim() > 0 and t.shape[0] == 1:
+        t = t.squeeze(0)
+    return t
+
+
+def register_func(tables, aten_ops, schema_str):
+    if not isinstance(aten_ops, list):
+        aten_ops = [aten_ops]
+    if not isinstance(tables, list):
+        tables = [tables]
+
+    def wrapper(func):
+        for aten_op in aten_ops:
+
+            def get_inner(aten_op):
+                def inner(*args, **kwargs):
+                    check_schema(schema_str, func, *args, **kwargs)
+                    return func(aten_op, *args, **kwargs)
+
+                return inner
+
+            for table in tables:
+                table[aten_op] = get_inner(aten_op)
+        return func
+
+    return wrapper
+
+
+# ``register_func`` pre-bound to JAGGED_OPS_TABLE; used below as
+# ``@register_jagged_func(aten_ops, schema_str)``.
+register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
+
+
+def lookup_jagged(func, *args, **kwargs) -> Callable | None:
+    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
+    if dispatch_func is not None:
+        return dispatch_func
+
+    # Handle pointwise fallbacks
+    if torch.Tag.pointwise in func.tags:
+        from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+        # No pointwise ops legitimately accept nested int inputs. Without this check,
+        # they will be incorrectly interpreted as tensors.
+        # See https://github.com/pytorch/pytorch/issues/138496
+        for arg in args:
+            if is_nested_int(arg):
+                raise RuntimeError(
+                    f"NestedTensor {func.__name__}: invalid argument {arg}"
+                )
+
+        # Assume there aren't additional tensors that aren't the "unary/binary" args
+        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
+        if num_tensor_args == 1:
+            # Build up the check schema string. The first tensor arg is assumed to be
+            # an NJT and other args are sent through as-is.
+            schema_parts = []
+            for arg in func._schema.arguments:
+                if isinstance(arg.type, torch.TensorType):
+                    schema_parts.append(f"{arg.name}: jt_all")
+                    break
+                else:
+                    schema_parts.append(f"{arg.name}: any")
+            schema_parts.append("...")
+            check_schema_str = ", ".join(schema_parts)
+            check_schema(check_schema_str, func, *args, **kwargs)
+            return functools.partial(jagged_unary_pointwise, func)
+        elif num_tensor_args == 2:
+            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
+            return functools.partial(jagged_binary_pointwise, func)
+
+    return None
+
+
+def extract_kwargs(arg):
+    kwargs = {
+        "offsets": arg.offsets(),
+        "lengths": arg.lengths(),
+        "_metadata_cache": arg._metadata_cache,
+        "_ragged_idx": arg._ragged_idx,
+    }
+    return kwargs
+
+
+def jagged_unary_pointwise(func, *args, **kwargs):
+    # assume if we get here that there is a single NJT input in the args
+    njt = next(arg for arg in args if isinstance(arg, NestedTensor))
+    return NestedTensor(
+        func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
+        **extract_kwargs(njt),
+    )
+
+
+def jagged_binary_pointwise(func, *args, **kwargs):
+    a, b = args[0], args[1]
+    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)
+
+    mismatch_error_msg = (
+        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
+    )
+    # a is NT, b is NT
+    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
+        # ex: (B, j0, D) + (B, j0, D)
+        # ex: (B, j0, D) + (B, j0, 1)
+        if raggedness_matches(a, b._size):
+            return NestedTensor(
+                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
+            )
+        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
+    # either a is NT or b is NT at this point
+    a_is_nt = isinstance(a, NestedTensor)
+    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)
+
+    # === Handle broadcasting across the batch / ragged dims ===
+
+    # Easy case: take advantage of pre-existing broadcasting logic
+    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    nt, t = (a, b) if a_is_nt else (b, a)
+    # See Note: [ Squeezing leading ones ]
+    if t.dim() > nt.dim():
+        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
+    t_squeezed = squeeze_leading_ones(t)
+    if nt.dim() >= t_squeezed.dim() + 2:
+        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
+        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
+
+    # Harder case: do manual broadcasting when NT dim == non-NT dim
+    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
+    if a.dim() == b.dim():
+        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
+        # be (B, j0, D_0, D_1) but not yet supported
+        if a.shape[0] != b.shape[0]:
+            raise RuntimeError(
+                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
+            )
+
+        from .nested_tensor import nested_from_padded
+
+        # handle broadcasting via padded dense -> jagged conversion
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = nt._values.shape[nt._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        # convert dense tensor -> jagged
+        t = t.expand(
+            [x if i != nt._ragged_idx else padded_max_S for i, x in enumerate(t.shape)]
+        )
+        t_as_nt = nested_from_padded(
+            t,
+            offsets=nt._offsets,
+            ragged_idx=nt._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        # function call with two NJTs
+        lhs, rhs = (nt, t_as_nt) if a_is_nt else (t_as_nt, nt)
+        return func(lhs, rhs, *args[2:], **kwargs)
+
+    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant
+    # that ragged dim is wrt left-most batch dim
+    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
+
+
+def jagged_torch_function(func, *args, **kwargs):
+    # SDPA has special kernels that handle nested tensors.
+    # Dispatch to the correct implementation here
+    if func is torch._C._nn.scaled_dot_product_attention:
+        return jagged_scaled_dot_product_attention(*args, **kwargs)
+
+    if func.__name__ == "apply_":
+        func(args[0]._values, *args[1:], **kwargs)
+        return args[0]
+
+    # Handle flatten() here because it's CompositeImplicit.
+    if func.__name__ == "flatten":
+
+        def _flatten_sig(input, start_dim=0, end_dim=-1) -> None:
+            pass
+
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
+            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+        )
+
+        inp = new_kwargs.pop("input")
+
+        # NB: stay in outer dim space because we're going to redispatch on a NT input
+        start_dim = _wrap_jagged_dim(
+            inp.dim(),
+            new_kwargs["start_dim"],
+            inp._ragged_idx,
+            "flatten",
+            convert_to_inner_dim=False,
+        )
+        end_dim = _wrap_jagged_dim(
+            inp.dim(),
+            new_kwargs["end_dim"],
+            inp._ragged_idx,
+            "flatten",
+            convert_to_inner_dim=False,
+        )
+
+        if start_dim == end_dim:
+            return inp
+
+        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
+        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])
+
+        return inp.reshape(*new_shape)
+
+    # Handle NestedTensor share_memory_.
+    if func.__name__ == "share_memory_":
+        nt = args[0]
+
+        if nt.is_cuda:
+            return nt
+
+        names, _ = nt.__tensor_flatten__()
+        with torch._C.DisableTorchFunctionSubclass():
+            for name in names:
+                component = getattr(nt, name, None)
+                if component is not None:
+                    component.share_memory_()
+        return nt
+
+    # Handle NestedTensor is_shared.
+    if func.__name__ == "is_shared":
+        nt = args[0]
+
+        if nt.is_cuda:
+            return False
+
+        names, _ = nt.__tensor_flatten__()
+        if not names:
+            return False
+        return all(
+            getattr(nt, name) is not None and getattr(nt, name).is_shared()
+            for name in names
+        )
+
+    # Handle nested-specific input validation for CompositeImplicit rms_norm
+    if func.__name__ == "rms_norm":
+
+        def _rms_norm_sig(input, normalized_shape, weight=None, eps=None) -> None:
+            pass
+
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
+            _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+        )
+
+        inp = new_kwargs.pop("input")
+        normalized_shape = new_kwargs.pop("normalized_shape")
+
+        # can't normalize over the ragged dim (yet)
+        max_normalizable = inp.dim() - inp._ragged_idx - 1
+        if len(normalized_shape) > max_normalizable:
+            raise ValueError(
+                "rms_norm(): Normalization over the ragged dim not supported for nested tensors"
+            )
+
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
+    raise NotImplementedError(func)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.is_non_overlapping_and_dense.default,
+        torch.ops.aten.sym_size.default,
+        torch.ops.aten.dim.default,
+        torch.ops.aten.numel.default,
+        torch.ops.aten.sym_numel.default,
+        torch.ops.aten.sym_stride.default,
+        torch.ops.aten.sym_storage_offset.default,
+    ],
+    "self: jt_all",
+)
+def tensor_attr_supported_getter(func, *args, **kwargs):
+    if func is torch.ops.aten.is_non_overlapping_and_dense.default:
+        return False
+
+    if func is torch.ops.aten.sym_size.default:
+        return args[0]._size
+
+    if func is torch.ops.aten.dim.default:
+        return len(args[0]._size)
+
+    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
+        if args[0]._lengths is not None:
+            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
+        return args[0]._values.numel()
+
+    if func is torch.ops.aten.sym_stride.default:
+        return args[0]._strides
+
+    if func is torch.ops.aten.sym_storage_offset.default:
+        return args[0]._values.storage_offset()
+
+
+@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
+def prim_layout_default(func, *args, **kwargs):
+    return torch.jagged
+
+
+@register_jagged_func(
+    [torch.ops.aten.size.default],
+    "self: jt_all",
+)
+def tensor_attr_unsupported_getter(func, *args, **kwargs) -> None:
+    if func is torch.ops.aten.size.default:
+        raise RuntimeError(
+            "NestedTensor does not support directly calling torch.ops.aten.size; "
+            "please use `nested_tensor.size()` instead."
+        )
+
+
+@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
+def is_contiguous_general(func, *args, **kwargs):
+    from torch._prims_common import is_contiguous_for_memory_format
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+    return is_contiguous_for_memory_format(inp._values, **new_kwargs)
+
+
+# Reuse is_contiguous_general for the overload that takes an explicit
+# (optional) memory_format argument.
+register_jagged_func(
+    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
+)(is_contiguous_general)
+
+
+@register_jagged_func(
+    torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
+)
+def sym_is_contiguous_general(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+
+    return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
+)
+def clone_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_meta = extract_kwargs(inp)
+
+    if inp._lengths is not None:
+        if new_kwargs["memory_format"] == torch.contiguous_format:
+            # need to copy to remove "holes" non-contiguity / lengths metadata
+            # TODO: write a kernel for this
+            from .nested_tensor import jagged_from_list
+
+            # TODO: We probably want the output to have the same ragged structure / nested int.
+            assert inp._ragged_idx == 1, (
+                "NJT with ragged_idx != 1 not supported for contiguous clone"
+            )
+            contig, _ = jagged_from_list(inp.unbind(), offsets=None)
+            return contig
+
+    return NestedTensor(func(inp._values, **new_kwargs), **new_meta)
+
+
+@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
+def linear_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.linear_backward.default,
+    "self: jt, grad_output: jt, weight: t, output_mask: any",
+)
+def linear_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    grad_output = new_kwargs.pop("grad_output")
+    weight = new_kwargs.pop("weight")
+    output_mask = new_kwargs.pop("output_mask")
+
+    ds, dw, db = None, None, None
+    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
+    if output_mask[0]:
+        ds = NestedTensor(
+            torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output)
+        )
+    if output_mask[1]:
+        # NB: Fold dims of values for input and grad_output to treat them as 2D. This
+        # trick avoids materializing large intermediates and immediately reducing over
+        # them via sum(). This is equivalent to computing:
+        #     torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
+        # and then summing over the leading dimensions to get a 2D weight grad.
+        grad_2d = grad_output._values.reshape(-1, weight.size(0))
+        input_2d = inp._values.reshape(-1, weight.size(1))
+        dw = torch.matmul(grad_2d.t(), input_2d)
+    if output_mask[2]:
+        # Sum over all but the last dim to get a 1D bias grad. We cannot
+        # rely on the autograd engine to reduce for us, because returning a
+        # tensor aliasing the input would violate the aten signature annotation
+        reduce_dims = tuple(range(grad_output._values.ndim - 1))
+        if reduce_dims == ():
+            db = grad_output._values.clone()
+        else:
+            db = torch.sum(grad_output._values, reduce_dims, keepdim=False)
+    return (ds, dw, db)
+
+
+@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
+def to_dtype(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
+def to_copy_default(func, *args, **kwargs):
+    from .nested_tensor import _tensor_symint_registry
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    # don't change layout
+    new_kwargs.pop("layout")
+
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import (
+        FunctionalTensor,
+        mb_unwrap_functional_tensor,
+    )
+
+    ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+    new_thing = new_offsets if new_lengths is None else new_lengths
+    if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+        # Temporary hack until we have the union find
+        tgt = mb_unwrap_functional_tensor(new_thing)
+        src = mb_unwrap_functional_tensor(ragged_source)
+        tgt.nested_int_memo = src.nested_int_memo
+    else:
+        _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+    inp_kwargs = extract_kwargs(inp)
+    inp_kwargs["offsets"] = new_offsets
+    inp_kwargs["lengths"] = new_lengths
+
+    output = NestedTensor(new_values, **inp_kwargs)
+    return output
+
+
+@register_jagged_func(
+    torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
+)
+def copy_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    src = new_kwargs.pop("src")
+    if inp._size != src._size:
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
+
+    # AOTD allows mutations of inputs only, (not views of the inputs).
+    # NJT.values() returns _values.detach() to workaround some issues.
+    # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
+    # Here we directly mutate self._values to not emit .detach() in the graph, which would make it non-compilable.
+    inp._values.copy_(src._values)
+    return inp
+
+
+# detach needs no dedicated handler: it is registered against the generic
+# unary pointwise implementation.
+register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
+    jagged_unary_pointwise
+)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.empty_like.default,
+        torch.ops.aten.ones_like.default,
+        torch.ops.aten.zeros_like.default,
+        torch.ops.aten.rand_like.default,
+        torch.ops.aten.randn_like.default,
+    ],
+    "self: jt_all",
+)
+def like_factory_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # Default layout is technically torch.strided but only jagged is supported here.
+    # Rather than force users to specify the layout, assume jagged.
+    # This should be set to strided for redispatching on values.
+    new_kwargs["layout"] = torch.strided
+
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+    output_kwargs = extract_kwargs(inp)
+    if "offsets" in output_kwargs:
+        output_kwargs["offsets"] = new_offsets
+    if "lengths" in output_kwargs:
+        output_kwargs["lengths"] = new_lengths
+
+    if inp.device != new_values.device:
+        # Update the nested int registry to indicate that the ragged structure is the same
+        # between the two offsets / lengths on different devices.
+        from torch._subclasses.fake_tensor import FakeTensor
+        from torch._subclasses.functional_tensor import (
+            FunctionalTensor,
+            mb_unwrap_functional_tensor,
+        )
+
+        from .nested_tensor import _tensor_symint_registry
+
+        ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+        new_thing = new_offsets if new_lengths is None else new_lengths
+        if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+            # Temporary hack until we have the union find
+            tgt = mb_unwrap_functional_tensor(new_thing)
+            src = mb_unwrap_functional_tensor(ragged_source)
+            tgt.nested_int_memo = src.nested_int_memo
+        else:
+            _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+
+    return NestedTensor(new_values, **output_kwargs)
+
+
+# The remaining *_like factories take extra positional arguments, so they are
+# registered separately with their own schema strings but share
+# like_factory_default as the implementation.
+register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
+    like_factory_default
+)
+
+register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
+    like_factory_default
+)
+
+register_jagged_func(
+    torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
+)(like_factory_default)
+
+
+@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
+def zero__default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    func(inp._values)
+    return inp
+
+
+@register_jagged_func(
+    torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
+)
+def _softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    if isinstance(new_kwargs["dim"], tuple):
+        raise RuntimeError(
+            "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
+        )
+
+    inp = new_kwargs.pop("input")
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        _reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        (new_kwargs["dim"],),
+        "softmax",
+        inp._ragged_idx,
+    )
+
+    if reduce_on_batch:
+        raise RuntimeError(
+            "softmax(): not supported when reducing across the batch dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged and inp._ragged_idx > 1:
+        raise RuntimeError(
+            "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
+        )
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            "softmax(): not supported where lengths is not None "
+            + "if reducing across the ragged dimension for NestedTensor"
+        )
+
+    new_kwargs["dim"] = new_kwargs["dim"][
+        0
+    ]  # torch.softmax takes in the reduction dimension as an integer
+
+    if reduce_on_ragged:
+        padded_softmax_values = torch.nn.functional.softmax(
+            torch.ops.aten._jagged_to_padded_dense_forward(
+                inp._values.reshape(
+                    inp._values.shape[0], -1
+                ),  # values are required to be 2D tensors for j2pd
+                [inp._offsets],
+                max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+                padding_value=float("-inf"),  # e^-inf = 0
+            ),
+            dim=inp._ragged_idx,
+        )
+
+        softmax_values = torch.ops.aten._padded_dense_to_jagged_forward(
+            padded_softmax_values,
+            [inp._offsets],
+            total_L=inp._values.shape[
+                0
+            ],  # providing this parameter helps avoid a GPU/CPU sync
+        ).reshape(
+            -1, *inp._values.shape[1:]
+        )  # expand softmax_values back to original shape (inp._values.shape)
+
+        return NestedTensor(softmax_values, **extract_kwargs(inp))
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any"
+)
+def _log_softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    if isinstance(new_kwargs["dim"], tuple):
+        raise RuntimeError(
+            "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
+        )
+
+    inp = new_kwargs.pop("input")
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        _reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(), (new_kwargs["dim"],), "log_softmax", inp._ragged_idx
+    )
+
+    if reduce_on_batch:
+        raise RuntimeError(
+            "log_softmax(): not supported when reducing across the batch dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged:
+        raise RuntimeError(
+            "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor"
+        )
+
+    # torch.log_softmax takes in the reduction dimension as an integer
+    new_kwargs["dim"] = new_kwargs["dim"][0]
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten._softmax_backward_data.default,
+    "grad_output: jt, output: jt, dim: any, input_dtype: any",
+)
+def _softmax_backward(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_out = new_kwargs.pop("grad_output")
+    output = new_kwargs.pop("output")
+    return NestedTensor(
+        func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
+)
+def native_dropout_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    out1, out2 = func(inp._values, **new_kwargs)
+    return (
+        NestedTensor(out1, **extract_kwargs(inp)),
+        NestedTensor(out2, **extract_kwargs(inp)),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.native_dropout_backward.default,
+    "grad_output: jt, mask: jt, scale: any",
+)
+def native_dropout_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_output = new_kwargs.pop("grad_output")
+    mask = new_kwargs.pop("mask")
+    return NestedTensor(
+        func(grad_output._values, mask._values, **new_kwargs),
+        **extract_kwargs(grad_output),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.prod.dim_int,
+    "self: jt_all, dim: any, keepdim: any?, dtype: any?",
+)
+def prod_dim_int(func, *args, **kwargs):
+    return _apply_reduction(func, "prod", 1, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.prod.default, "self: jt_all, dtype: any?")
+def prod_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any?"
+)
+def split_tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split"
+    )
+
+    return tuple(
+        NestedTensor(values=x, **extract_kwargs(inp))
+        for x in func(inp._values, **new_kwargs)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any?"
+)
+def split_with_sizes_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split_with_sizes"
+    )
+
+    return [
+        NestedTensor(values=x, **extract_kwargs(inp))
+        for x in func(inp._values, **new_kwargs)
+    ]
+
+
+@register_jagged_func(
+    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
+)
+def narrow(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
+    values = func(
+        inp._values,
+        dim=dim,
+        start=new_kwargs["start"],
+        length=new_kwargs["length"],
+    )
+    return NestedTensor(values, **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
+def chunk_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True
+    )
+
+    if operating_on_batch:
+        chunks = new_kwargs["chunks"]
+
+        # get _offsets of the chunks
+        lengths = inp._offsets.diff()
+        chunked_lengths = lengths.chunk(chunks)
+        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
+        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]  # type: ignore[arg-type]
+        nested_kwargs = [
+            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
+            for per_offsets in chunked_offsets
+        ]
+
+        # get _values of the chunks
+        split_sizes = [x.sum().item() for x in chunked_lengths]
+        chunk_values = inp._values.split(split_sizes)
+
+        # Note that the actual number of chunks returned is not necessarily the same as
+        # the input number; it can be counter-intuitive, but it matches dense behavior.
+        return [
+            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
+            for i in range(len(chunk_values))
+        ]
+    else:
+        return [
+            NestedTensor(values=x, **extract_kwargs(inp))
+            for x in func(inp._values, **new_kwargs)
+        ]
+
+
+@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
+def unbind_int(func, *args, **kwargs):
+    """Unbind a jagged nested tensor along dim 0 into its dense components.
+
+    Without ``lengths``, the values are split along ``ragged_idx - 1`` using
+    the differences of consecutive offsets. With ``lengths``, each component
+    is a ``torch.narrow`` of the values starting at its offset.
+
+    Raises:
+        RuntimeError: if ``dim`` is not 0, or (on the ``lengths`` path only)
+            if ``ragged_idx`` is not >= 1.
+    """
+    # Note that this specializes on the length of the offsets
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dim = new_kwargs["dim"]
+    if dim != 0:
+        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")
+
+    inp = new_kwargs.pop("input")
+    values = inp.values()
+    offsets = inp.offsets()
+    lengths = inp.lengths()
+    ragged_idx = inp._ragged_idx
+
+    def _torch_check(_lengths: list[int], _offsets: list[int] | None = None) -> None:
+        """Register bounds checks on lengths (and optionally offsets) via torch._check."""
+        # These torch._check calls are needed for torch.compile
+        # symbolic shapes processing.
+        # offsets and lengths are symbolic variables during compilation, so
+        # we assert the expected offsets/lengths relationships explicitly:
+        # sum of lengths <= total ragged_dim_size
+        # every length and offset is a size-like variable (allows sym shapes to reason about it as [2, inf))
+        # offset[i] + length[i] <= ragged_dim_size, for unbind and split dim correctness
+        # offsets[i] <= ragged_dim_size
+
+        lengths_sum = 0
+        # Size of the packed values along the ragged dimension.
+        ragged_dim_size = values.shape[ragged_idx - 1]
+        for i in range(len(_lengths)):
+            torch._check(_lengths[i] >= 0)
+            torch._check(_lengths[i] <= ragged_dim_size)
+
+            lengths_sum += _lengths[i]
+            if _offsets is not None:
+                torch._check(
+                    _offsets[i] + _lengths[i] <= ragged_dim_size,
+                    lambda: "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension",
+                )
+        torch._check(lengths_sum <= ragged_dim_size)
+
+        if _offsets is not None:
+            for i in range(len(_offsets)):
+                torch._check(_offsets[i] >= 0)
+                torch._check(_offsets[i] <= ragged_dim_size)
+
+    if lengths is None:
+        # No explicit lengths: derive them from consecutive offsets.
+        lengths_scalars = offsets.diff().tolist()
+        _torch_check(lengths_scalars)
+
+        return torch.split(values, lengths_scalars, dim=(ragged_idx - 1))
+
+    if ragged_idx <= 0:
+        raise RuntimeError(
+            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
+        )
+
+    lengths_scalars = lengths.tolist()
+    offsets_scalars = offsets.tolist()
+
+    _torch_check(lengths_scalars, offsets_scalars)
+
+    # Explicit lengths: each component is narrowed out individually using its
+    # own offset and length.
+    return [
+        torch.narrow(
+            values,
+            dim=(ragged_idx - 1),
+            start=offsets_scalars[i],
+            length=lengths_scalars[i],
+        )
+        for i in range(lengths.shape[0])
+    ]
+
+
+@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
+def squeeze_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    values = inp._values
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze"
+    )
+    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt_all, dim: any")
+def unsqueeze_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    values = inp._values
+
+    # Account for collapsed jagged dim
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(inp._size) + 1, dim, inp._ragged_idx, "unsqueeze", allow_ragged_dim=True
+    )
+
+    # ragged_idx changes if a dimension is added before it
+    output_kwargs = extract_kwargs(inp)
+    if new_kwargs["dim"] <= inp._ragged_idx - 1:
+        output_kwargs["_ragged_idx"] += 1
+
+    return NestedTensor(func(values, **new_kwargs), **output_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
+def cat_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    tensors = new_kwargs.pop("tensors")
+
+    # Convert any non-nested to nested
+    nested = [t for t in tensors if t.is_nested]
+    assert len(nested) > 0
+    first = nested[0]
+    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]
+
+    # Account for collapsed jagged dim
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(first.shape), dim, first._ragged_idx, "cat"
+    )
+
+    return NestedTensor(
+        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
+    )
+
+
+@register_jagged_func(torch.ops.aten.matmul.default, "self: any, other: any")
+def matmul_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    def _unbind_impl(a, b):
+        return [
+            func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind())
+        ]
+
+    def _padded_impl(a, b):
+        if a.is_nested:
+            nt = a
+        else:
+            nt = b
+
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = nt._values.shape[nt._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *nt.shape[: nt._ragged_idx],
+            padded_max_S,
+            *nt.shape[nt._ragged_idx + 1 :],
+        )
+        padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape)
+        if a.is_nested:
+            padded_t = func(padded_nt, b)
+        else:
+            padded_t = func(a, padded_nt)
+        return nested_from_padded(
+            padded_t,
+            offsets=nt._offsets,
+            ragged_idx=nt._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+    # TODO: Back these with proper kernels (e.g. grouped GEMM)
+    # NJT x dense
+    if inp.is_nested and not other.is_nested:
+        # (B, j1, D) x (B, D, E) => (B, j1, E)
+        if (
+            inp.dim() >= 3
+            and inp.dim() == other.dim()
+            and inp._ragged_idx < inp.dim() - 1
+        ):
+            # convert to padded for this
+            return _padded_impl(inp, other)
+        # Support broadcasting the dense:
+        # (B, j1, D) x (D, E) => (B, j1, E)
+        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
+        # etc.
+        elif (
+            other.dim() == 2
+            and inp.dim() > other.dim()
+            and inp._ragged_idx < inp.dim() - 1
+        ):
+            return NestedTensor(
+                func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
+            )
+    # Dense x NJT
+    elif not inp.is_nested and other.is_nested:
+        # (B, D, E) x (B, E, j1) => (B, E, j1)
+        if other.dim() >= 3 and other.dim() == inp.dim() and other._ragged_idx >= 2:
+            # convert to padded for this
+            return _padded_impl(inp, other)
+        # Support broadcasting the dense:
+        # (D, E) x (B, E, j1) => (B, D, j1)
+        # (D, E) x (B, E, j1, F) => (B, D, j1, F)
+        # etc.
+        elif inp.dim() == 2 and other.dim() > inp.dim() and other._ragged_idx >= 2:
+            return NestedTensor(
+                func(inp, other._values, **new_kwargs), **extract_kwargs(other)
+            )
+
+    # NJT x NJT
+    elif inp.is_nested and other.is_nested:
+        # Support ragged batch dim:
+        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc.
+        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
+            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))
+        # Support reducing over ragged with dense output:
+        # (B, D, j1) x (B, j1, E) => (B, D, E)
+        elif (
+            inp.dim() == 3
+            and other.dim() == 3
+            and inp._ragged_idx == 2
+            and other._ragged_idx == 1
+            and inp.size(inp._ragged_idx) == other.size(other._ragged_idx)
+        ):
+            # do unbind for this; can't use padded conversion due to j1 in last dim
+            return torch.stack(_unbind_impl(inp, other))
+
+    raise RuntimeError(
+        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
+    )
+
+
+@register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any")
+def bmm_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("mat2")
+
+    if inp.dim() != 3:
+        raise ValueError("bmm(): input must be 3D")
+    if other.dim() != 3:
+        raise ValueError("bmm(): mat2 must be 3D")
+
+    return matmul_default(torch.ops.aten.matmul.default, inp, other)
+
+
+@register_jagged_func(
+    torch.ops.aten.expand.default, "self: jt_all, size: any, implicit: any?"
+)
+def expand_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs["size"]
+
+    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
+    if not raggedness_matches(inp, size):
+        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")
+
+    expand_arg = [-1 if d == inp._ragged_idx else size[d] for d in range(1, inp.dim())]
+    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
+def expand_as_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    return NestedTensor(func(inp, other._values), **extract_kwargs(other))
+
+
+@register_jagged_func(torch.ops.aten.broadcast_to.default, "self: jt_all, size: any")
+def broadcast_to(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs.pop("size")
+
+    if len(size) <= inp.dim():
+        return inp.expand([*(1 for _ in range(inp.dim() - len(size))), *size])
+
+    raise ValueError(
+        "broadcast_to(): broadcasting to a higher-dim shape is currently not supported "
+        "for nested tensors with the jagged layout"
+    )
+
+
+@register_jagged_func(torch.ops.aten.broadcast_tensors.default, "tensors: any")
+def broadcast_tensors(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    tensors = new_kwargs.pop("tensors")
+    if len(tensors) == 0:
+        raise ValueError("broadcast_tensors(): expected at least one tensor input")
+    if len(tensors) == 1:
+        return tensors[0]
+
+    outs = []
+    broadcast_shape = torch.broadcast_shapes(*(t.shape for t in tensors))
+    # Pull out the first NJT. If broadcast_shapes() worked, the nested ints are compatible.
+    njt = next(t for t in tensors if isinstance(t, NestedTensor))
+    for t in tensors:
+        if t.is_nested:
+            outs.append(t.broadcast_to(broadcast_shape))
+        elif t.dim() < len(broadcast_shape):
+            outs.append(
+                NestedTensor(t.broadcast_to(njt._values.shape), **extract_kwargs(njt))
+            )
+        else:
+            raise ValueError(
+                "broadcast_tensors(): broadcasting nested tensors with dense tensors of equal "
+                "or higher dim is not currently supported"
+            )
+
+    return tuple(outs)
+
+
+@register_jagged_func(
+    torch.ops.aten.where.self, "condition: jt_all, self: any, other: any"
+)
+def where_self(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    condition = new_kwargs.pop("condition")
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    # if the tensors aren't compatible, broadcast_tensors() will let us know
+    condition, inp, other = torch.broadcast_tensors(condition, inp, other)
+
+    return NestedTensor(
+        func(condition._values, inp._values, other._values, **new_kwargs),
+        **extract_kwargs(condition),
+    )
+
+
+@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
+def _pin_memory_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
+def is_pinned_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
+)
+def is_same_size_default(func, *args, **kwargs):
+    return args[0]._size == args[1]._size
+
+
+def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
+    """Shared implementation for dim-wise reductions over jagged nested tensors.
+
+    Args:
+        func: the reduction op to run; called on the packed values or on a
+            padded dense tensor depending on which dims are reduced.
+        func_name: op name used in error messages and passed to
+            ``_wrap_jagged_dims``.
+        identity_element: padding value passed to ``to_padded_tensor`` when
+            reducing over the ragged dim only.
+        *args, **kwargs: the original op arguments; normalized to kwargs here.
+
+    Returns:
+        Whatever ``func`` returns when the ragged dim is reduced away or for a
+        full reduction; otherwise each output of ``func`` wrapped as a
+        ``NestedTensor``.
+
+    Raises:
+        RuntimeError: for unsupported dim combinations (ragged + non-batch,
+            batch without ragged) or a ragged reduction when ``_lengths`` is
+            set.
+    """
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # some ops use dim=None to indicate a full reduction; some use an empty dim list
+    full_reduction = new_kwargs["dim"] is None or (
+        isinstance(new_kwargs["dim"], (tuple, list)) and len(new_kwargs["dim"]) == 0
+    )
+    if full_reduction:
+        out = func(inp._values, **new_kwargs)
+        if new_kwargs.get("keepdim", False):
+            # the reduction ran on the packed values, so insert one extra dim
+            # at the ragged index of the output
+            if isinstance(out, (tuple, list)):
+                # some ops return multiple things; unsqueeze all of them
+                out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
+            else:
+                out = out.unsqueeze(inp._ragged_idx)
+        return out
+
+    # some ops support lists of dims; some don't
+    dim_to_convert = new_kwargs["dim"]
+    is_dimlist = isinstance(new_kwargs["dim"], (tuple, list))
+    if not is_dimlist:
+        # _wrap_jagged_dims is always called with a list of dims
+        dim_to_convert = [dim_to_convert]
+
+    (
+        converted_dim,
+        reduce_on_batch,
+        reduce_on_ragged,
+        reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        dim_to_convert,
+        f"{func_name}",
+        inp._ragged_idx,
+    )
+
+    if not is_dimlist:
+        # convert back from list
+        converted_dim = converted_dim[0]
+    new_kwargs["dim"] = converted_dim
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            f"{func_name}(): reducing across the ragged dimension is not supported "
+            "for non-contiguous nested tensors with holes"
+        )
+
+    from torch.utils._pytree import tree_map
+
+    # raggedness reduced away --> return dense tensor
+    if reduce_on_ragged:
+        # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
+        if reduce_on_batch:
+            # no need to read offsets --> apply the reduction directly on values
+            out = func(inp._values, **new_kwargs)
+            if new_kwargs.get("keepdim", False):
+                # some ops return multiple things; unsqueeze all of them
+                out = tree_map(lambda o: o.unsqueeze(0), out)
+            return out
+        else:
+            # invalid reduction cases: (ragged, non-batch), etc.
+            if reduce_on_non_batch:
+                raise RuntimeError(
+                    f"{func_name}(): reducing along a ragged and non-batch dimension "
+                    "is not supported for nested tensors"
+                )
+
+            # reduction cases: (ragged)
+            # convert to padded dense and reduce
+            # the padded tensor is indexed with outer dims, so the converted
+            # inner dim is dropped and the ragged index is passed instead
+            new_kwargs.pop("dim")
+            dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
+            return func(
+                inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
+            )
+    # raggedness preserved --> return nested tensor
+    else:
+        # invalid reduction cases: (batch), (batch, non-batch), etc.
+        if reduce_on_batch:
+            raise RuntimeError(
+                f"{func_name}(): reducing along the batch dimension but not "
+                "the ragged dimension is not supported for nested tensors"
+            )
+
+        # reduction cases: (non-batch), (non-batch, non-batch), etc.
+        # apply the reduction directly on values
+        out = func(inp._values, **new_kwargs)
+        out_kwargs = extract_kwargs(inp)
+        if not new_kwargs.get("keepdim", False):
+            # dims are reduced away -> ragged_idx of output needs to be reevaluated
+            dimlist = (
+                new_kwargs["dim"]
+                if isinstance(new_kwargs["dim"], (tuple, list))
+                else [new_kwargs["dim"]]
+            )
+            for d in dimlist:
+                # adjust for all dims reduced before the ragged dim
+                if d < inp._ragged_idx - 1:
+                    out_kwargs["_ragged_idx"] -= 1
+
+        # some ops return multiple things; wrap each of them as an NJT
+        return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)
+
+
+@register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
+def sum_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.sum.dim_IntList,
+    "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
+)
+def sum_dim_IntList(func, *args, **kwargs):
+    return _apply_reduction(func, "sum", 0, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
+)
+def transpose_int(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    from torch._prims_common import canonicalize_dims
+
+    inp = new_kwargs.pop("input")
+    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))
+
+    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
+    # instead of 1, although the internal Flash and mem-effn implementations will
+    # use the inputs with raggedness in dim 1.
+    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
+        if dim0 == 0 or dim1 == 0:
+            raise ValueError(
+                "Transpose is not supported on the batch dimension for jagged NT"
+            )
+        if dim0 == inp._ragged_idx:
+            to_dim = dim1
+        else:
+            to_dim = dim0
+        inp_kwargs = extract_kwargs(inp)
+        inp_kwargs["_ragged_idx"] = to_dim
+        return NestedTensor(
+            inp.values().transpose(
+                _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
+                _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
+            ),
+            **inp_kwargs,
+        )
+
+    new_kwargs["dim0"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose"
+    )
+    new_kwargs["dim1"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose"
+    )
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
+def permute_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    dims = new_kwargs.pop("dims")
+    inp_kwargs = extract_kwargs(inp)
+    inp_dim = len(inp._size)
+
+    # The first two checks are the same as the checks in the normal permute implementation
+    if inp_dim != len(dims):
+        raise ValueError(
+            f"permute(): number of dimensions in the tensor input ({inp_dim}) "
+            + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
+        )
+
+    from torch._prims_common import canonicalize_dims
+
+    canonicalized_dims = canonicalize_dims(inp_dim, dims)
+
+    if len(canonicalized_dims) != len(set(canonicalized_dims)):
+        raise ValueError("permute(): duplicate dims are not allowed.")
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "permute(): not supported on jagged layout nested tensor with holes"
+        )
+    if canonicalized_dims[0] != 0:
+        raise ValueError(
+            "Permute is not supported on the batch dimension for jagged NT"
+        )
+    inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
+    inner_dims = [
+        _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
+        for dim in canonicalized_dims[1:]
+    ]
+    new_kwargs["dims"] = inner_dims
+    return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
+
+
+@register_jagged_func(
+    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
+    "self: jt_all, size: any",
+)
+def view_default(func, *args, **kwargs):
+    """View a jagged NestedTensor with a new outer ``size``.
+
+    The requested outer size is translated into a size for the values buffer,
+    ``func`` is applied to ``inp._values`` with that size, and the result is
+    wrapped in a ``NestedTensor`` built from ``extract_kwargs(inp)``.
+
+    Raises:
+        RuntimeError: if ``inp._ragged_idx != 1`` while ``size`` differs from
+            ``inp._size``; if ``size`` has fewer than 3 entries; or if
+            ``raggedness_matches(inp, size)`` is false.
+    """
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs.pop("size")
+
+    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
+        raise RuntimeError(
+            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
+            f"inp._size is ({inp._size}) and size is ({size})."
+        )
+
+    # Ensure specified size still includes batch and ragged dims
+    if len(size) < 3 or not raggedness_matches(inp, size):
+        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
+
+    # outer size: the size of the NT, e.g. [3, j0, 10]
+    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
+    # this function gets inner_size[inner_idx] for a given inner_idx.
+    #
+    # example: for outer size [a, b, c, j0, d, e, f]
+    #                         assume that j0 is ragged, other are concrete integers
+    #                         and ragged_idx=3
+    # inner size will be      [b, c, inp._values.size(ragged_idx - 1), d, e, f]
+    # therefore:
+    #    inner_size[0] = outer_size[1]
+    #    inner_size[1] = outer_size[2]
+    #    inner_size[2] = inp._values.size(ragged_idx - 1)
+    #    inner_size[3] = outer_size[4]
+    #    inner_size[4] = outer_size[5]
+    #    inner_size[5] = outer_size[6]
+    def get_inner_size(inner_idx):
+        nonlocal inp, size
+        if inner_idx == inp._ragged_idx - 1:
+            # the ragged position takes its extent from the values buffer itself
+            return inp._values.size(inner_idx)
+        else:
+            # every other inner dim is the requested outer dim shifted past batch
+            return size[inner_idx + 1]
+
+    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]
+
+    # Preserve inference-mode-ness of input.
+    # TODO: Do this for all other views!
+    with torch.inference_mode(inp.is_inference()):
+        return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.native_layer_norm.default,
+    "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
+)
+def native_layer_norm_default(func, *args, **kwargs):
+    """Layer norm for a jagged NestedTensor.
+
+    When the ragged size is not part of ``normalized_shape`` the op is simply
+    redispatched on ``inp._values``. When it is, the values are padded to a
+    dense tensor, normalized over dims ``(1, 2)`` with a mask that zeroes the
+    padding, and converted back to jagged form.
+
+    Returns:
+        A 3-tuple ``(NestedTensor, mean, std)``.
+
+    Raises:
+        RuntimeError: if the input has 2 or fewer dims, if
+            ``normalized_shape`` covers every dim (i.e. includes the batch
+            dim), or if the ragged dim is normalized while ``inp._lengths``
+            is set.
+    """
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if inp.dim() <= 2:
+        raise RuntimeError(
+            "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions"
+        )
+
+    normalized_shape = new_kwargs["normalized_shape"]
+    ragged_size = inp.shape[inp._ragged_idx]
+
+    num_dims_not_normalized = inp.dim() - len(normalized_shape)
+
+    if (
+        num_dims_not_normalized == 0
+    ):  # error if trying to normalize over the batch dimension
+        raise RuntimeError(
+            "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor"
+        )
+
+    if ragged_size in normalized_shape and inp._lengths is not None:
+        raise RuntimeError(
+            "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor"
+        )
+
+    # NOTE(review): the branch below never reads new_kwargs["weight"] or
+    # new_kwargs["bias"], and its third return value is sqrt(var + eps) rather
+    # than a reciprocal -- confirm both are intended.
+    if (
+        ragged_size in normalized_shape
+    ):  # special handling for normalizing over the ragged dimension
+        padded_input = torch.ops.aten._jagged_to_padded_dense_forward(
+            inp._values.flatten(
+                start_dim=inp._ragged_idx
+            ),  # _jagged_to_padded_dense_forward requires values to be a 2D tensor
+            [inp._offsets],
+            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+        )
+
+        padded_mask = torch.ops.aten._jagged_to_padded_dense_forward(
+            torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype),
+            [inp._offsets],
+            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+        ).expand(
+            padded_input.shape
+        )  # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor)
+
+        ragged_lengths = (
+            inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2]
+        )  # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize)
+
+        mean = (
+            torch.sum(
+                padded_input,
+                dim=(1, 2),
+                keepdim=True,
+            )
+            / ragged_lengths
+        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
+
+        padded_normalized = (
+            (padded_input - mean) * padded_mask
+        )  # mask elements outside of the ragged dimension size for correct variance calculation
+
+        variance = (
+            torch.sum(
+                torch.square(padded_normalized),
+                dim=(1, 2),
+                keepdim=True,
+            )
+            / ragged_lengths
+        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
+
+        std = torch.sqrt(variance + new_kwargs["eps"])
+        padded_layer_norm = padded_normalized / std
+
+        jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward(
+            padded_layer_norm,
+            [inp._offsets],
+            total_L=inp._values.shape[
+                0
+            ],  # providing this parameter helps avoid a GPU/CPU sync
+        ).unflatten(
+            -1, inp.shape[inp._ragged_idx + 1 :]
+        )  # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H)
+
+        return (
+            NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)),
+            mean,
+            std,
+        )
+
+    # ragged dim is not normalized: run the dense op directly on the values buffer
+    output, mean, std = func(inp._values, **new_kwargs)
+    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)
+
+
+@register_jagged_func(
+    torch.ops.aten.native_layer_norm_backward.default,
+    "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
+)
+def native_layer_norm_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_out = new_kwargs.pop("grad_out")
+    inp = new_kwargs.pop("input")
+    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
+    if d_input is None:
+        return (None, d_gamma, d_beta)
+
+    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)
+
+
+@register_jagged_func(torch.ops.aten.select.int, "self: jt_all, dim: any, index: any")
+def select_int(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
+    )
+
+    # handle batch dim slicing via unbind() for now
+    # TODO: make this more efficient
+    if operating_on_batch:
+        return inp.unbind()[new_kwargs["index"]]
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "select(): not yet supported on dim != 0 for non-contiguous nested tensor with holes"
+        )
+
+    # if selecting before the ragged dim, adjust output ragged_idx
+    out_kwargs = extract_kwargs(inp)
+    if new_kwargs["dim"] < inp._ragged_idx - 1:
+        out_kwargs["_ragged_idx"] -= 1
+
+    return NestedTensor(func(inp._values, **new_kwargs), **out_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.slice.Tensor,
+    "self: jt, dim: any?, start: any?, end: any?, step: any?",
+)
+def slice_tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice"
+    )
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.index_put.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+@register_jagged_func(
+    torch.ops.aten.index_put_.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+def index_put_(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp: NestedTensor = new_kwargs.pop("input")
+
+    # For index_put_ to work, we add together the indices of the ragged dimension
+    # and the batch dimension, adding the offsets of each ragged dimension to its
+    # indices
+
+    indices = new_kwargs.pop("indices")
+
+    assert len(indices) <= inp.dim()
+
+    if len(indices) < inp._ragged_idx + 1:
+        if not inp.is_contiguous():
+            raise RuntimeError(
+                "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
+            )
+        # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = inp._maybe_min_seqlen
+        max_seqlen = inp._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = inp._values.shape[inp._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *inp.shape[: inp._ragged_idx],
+            padded_max_S,
+            *inp.shape[inp._ragged_idx + 1 :],
+        )
+        padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
+        new_njt = nested_from_padded(
+            func(padded_inp, indices, **new_kwargs),
+            offsets=inp._offsets,
+            ragged_idx=inp._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        if func is torch.ops.aten.index_put_.default:
+            inp._values.copy_(new_njt.values())
+            return inp
+        return new_njt
+
+    # We can run on the underlying values directly
+
+    # Validate indices
+    if inp.lengths() is None:
+        lengths = inp.offsets().diff()
+    else:
+        lengths = inp.lengths()
+    torch._assert_async(
+        # pyrefly: ignore [no-matching-overload]
+        torch.all(indices[inp._ragged_idx] < lengths),
+        "Some indices in the ragged dimension are out of bounds!",
+    )
+
+    # Recompute indices for _values
+    ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
+    func_indices = (
+        # before ragged dim
+        indices[1 : inp._ragged_idx]
+        # ragged dim (combined with batch)
+        + [ragged_indices]
+        # after ragged dim
+        + indices[inp._ragged_idx + 1 :]
+    )
+
+    if func is torch.ops.aten.index_put_.default:
+        inp._values = func(inp._values, func_indices, **new_kwargs)
+        return inp
+
+    return NestedTensor(
+        func(inp._values, func_indices, **new_kwargs),
+        **extract_kwargs(inp),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.convolution.default,
+    "input: jt, weight: t, bias: t?, stride: any, padding: any, "
+    "dilation: any, transposed: any, output_padding: any, groups: any",
+)
+def convolution_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?"
+)
+def mean_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs["input"]
+    (_, reduce_on_batch, reduce_on_ragged, reduce_on_non_batch) = _wrap_jagged_dims(
+        inp.dim(),
+        new_kwargs["dim"],
+        "mean",
+        inp._ragged_idx,
+    )
+
+    if reduce_on_ragged and not reduce_on_batch:
+        assert not reduce_on_non_batch
+        # calculate an intermediate sum and leave the dim in for normalization purposes
+        keepdim = new_kwargs["keepdim"]
+        new_kwargs["keepdim"] = True
+        intermediate_sum = _apply_reduction(
+            torch.ops.aten.sum.dim_IntList, "mean", 0, **new_kwargs
+        )
+
+        # normalize by sequence lengths
+        lengths = inp._lengths if inp._lengths is not None else inp._offsets.diff()
+        for _ in range(intermediate_sum.dim() - 1):
+            lengths = lengths.unsqueeze(-1)
+        out = intermediate_sum / lengths
+        if not keepdim:
+            out = out.squeeze(inp._ragged_idx)
+        return out
+
+    # at this point, we're just redispatching on the values buffer
+    # since we expect it to be unused, specify a weird intermediate value to
+    # hopefully make errors obvious
+    intermediate_value = 0.42
+    return _apply_reduction(func, "mean", intermediate_value, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.mean.default, "self: jt_all, dtype: any?")
+def mean_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.any.dims, "self: jt_all, dim: any?, keepdim: any?")
+def any_dims(func, *args, **kwargs):
+    return _apply_reduction(func, "any", False, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.any.dim, "self: jt_all, dim: any, keepdim: any?")
+def any_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # wrap dim in list to redispatch to dims overload
+    new_kwargs["dim"] = [new_kwargs["dim"]]
+    return any_dims(torch.ops.aten.any.dims, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.all.dims, "self: jt_all, dim: any?, keepdim: any?")
+def all_dims(func, *args, **kwargs):
+    return _apply_reduction(func, "all", True, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.all.dim, "self: jt_all, dim: any, keepdim: any?")
+def all_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # wrap dim in list to redispatch to dims overload
+    new_kwargs["dim"] = [new_kwargs["dim"]]
+    return all_dims(torch.ops.aten.all.dims, **new_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.all.default,
+        torch.ops.aten.any.default,
+        torch.ops.aten.max.default,
+        torch.ops.aten.min.default,
+    ],
+    "self: jt_all",
+)
+def all_any_max_min_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    [torch.ops.aten._is_all_true.default, torch.ops.aten._is_any_true.default],
+    "self: jt_all",
+)
+def _is_true_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return func(inp._values)
+
+
+@register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
+def min_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_max = _get_padding_value(dtype, "max")
+    return _apply_reduction(func, "min", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.max.dim, "self: jt_all, dim: any, keepdim: any?")
+def max_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_min = _get_padding_value(dtype, "min")
+    return _apply_reduction(func, "max", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.amin.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def amin_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_max = _get_padding_value(dtype, "max")
+    return _apply_reduction(func, "amin", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.amax.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def amax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_min = _get_padding_value(dtype, "min")
+    return _apply_reduction(func, "amax", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.argmin.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def argmin_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_max = _get_padding_value(dtype, "max")
+    return _apply_reduction(func, "argmin", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.argmax.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def argmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_min = _get_padding_value(dtype, "min")
+    return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.value_selecting_reduction_backward.default,
+    "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any",
+)
+def value_selecting_reduction_backward_default(func, *args, **kwargs):
+    from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = new_kwargs.pop("grad")
+    new_kwargs["grad"] = grad._values
+    indices = new_kwargs.pop("indices")
+    new_kwargs["indices"] = indices._values
+    # should always succeed; sizes should contain a nested int
+    ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s))
+    # convert dim -> values-space dim
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(new_kwargs["sizes"]),
+        new_kwargs["dim"],
+        ragged_idx,
+        "value_selecting_reduction_backward",
+    )
+    # convert saved NJT sizes -> values-space sizes
+    sizes = new_kwargs.pop("sizes")
+    sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1)
+    sizes = sizes[1:]
+    new_kwargs["sizes"] = sizes
+
+    output_kwargs = extract_kwargs(indices)
+    output_kwargs["_ragged_idx"] = ragged_idx
+
+    return NestedTensor(func(**new_kwargs), **output_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
+def stack_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # guaranteed this is non-empty if we got here
+    tensors = new_kwargs.pop("tensors")
+    for t in tensors:
+        if not isinstance(t, NestedTensor):
+            raise RuntimeError("stack(): expected all nested tensors inputs")
+
+        if t.dim() != tensors[0].dim():
+            raise RuntimeError(
+                "stack(): expected all nested tensors to have the same dim"
+            )
+
+        if not raggedness_matches(t, tensors[0].shape):
+            raise RuntimeError(
+                "stack(): expected all nested tensors to have the same nested structure"
+            )
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack"
+    )
+
+    return NestedTensor(
+        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.embedding.default,
+    "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
+)
+def embedding_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # guaranteed this is non-empty if we got here
+    indices = new_kwargs.pop("indices")
+    weight = new_kwargs.pop("weight")
+
+    return NestedTensor(
+        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.embedding_dense_backward.default,
+    "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
+)
+def embedding_dense_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    indices = new_kwargs.pop("indices")
+    grad_output = new_kwargs.pop("grad_output")
+    return func(grad_output._values, indices._values, **new_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.values.default,
+        torch.ops.aten._nested_get_values.default,
+    ],
+    "self: jt_all",
+)
+def values_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # TODO: Handle inference mode properly.
+    # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
+    return inp._values.detach()
+
+
+@register_jagged_func(torch.ops.aten.all.default, "self: jt_all")
+def all_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values)
+
+
+@register_jagged_func(
+    torch.ops.aten.to_padded_tensor.default,
+    "self: jt_all, padding: any, output_size: any?",
+)
+def to_padded_tensor_default(func, *args, **kwargs):
+    """Convert a jagged NestedTensor into a padded dense tensor.
+
+    The values buffer is reshaped to 2D with the ragged dim first, converted
+    with ``_jagged_to_padded_dense_forward`` using ``padding`` as the fill
+    value, and then reshaped back.
+
+    Only the ragged entry of ``output_size`` is consulted; when
+    ``output_size`` is ``None`` the padded length is ``inp._max_seqlen`` if
+    ``inp._max_seqlen_tensor`` is set, else ``inp._values.size(0)``.
+
+    Raises:
+        RuntimeError: if ``inp._lengths`` is not ``None``.
+    """
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if inp._lengths is not None:
+        raise RuntimeError(
+            "to_padded_tensor(): not supported for nested tensors with holes"
+        )
+
+    # TODO: Handle the rest of output_size
+    output_size = new_kwargs["output_size"]
+    if output_size is not None:
+        max_seq_len = output_size[inp._ragged_idx]
+    else:
+        max_seq_len = (
+            inp._max_seqlen
+            if inp._max_seqlen_tensor is not None
+            else inp._values.size(0)
+        )
+
+    # only 2D values with ragged packed dim=0 is supported by the underlying FBGEMM
+    # kernel so do shape gymnastics if needed
+    values = inp.values()
+    if inp._ragged_idx > 1:
+        # move the ragged dim of the values buffer to the front
+        values = values.transpose(inp._ragged_idx - 1, 0)
+    # remember the pre-flatten shape so it can be restored after padding
+    values_shape = values.shape
+    if values.dim() > 2:
+        values = values.flatten(start_dim=1)
+    elif values.dim() == 1:
+        values = values.unsqueeze(-1)
+
+    # NB: The CUDA kernel for jagged -> padded dense conversion does not support
+    # integer / bool types; work around this by casting to half.
+    # NOTE(review): only the bool case is actually cast below.
+    is_bool = values.dtype is torch.bool
+    if is_bool and values.is_cuda:
+        values = values.to(torch.half)
+    padded_out = torch.ops.aten._jagged_to_padded_dense_forward(
+        values,
+        [inp._offsets],
+        [max_seq_len],
+        new_kwargs["padding"],
+    )
+    if is_bool and padded_out.is_cuda:
+        padded_out = padded_out.to(torch.bool)
+
+    # shape gymnastics part 2
+    if len(values_shape) > 2:
+        padded_out = padded_out.unflatten(-1, values_shape[1:])
+    elif len(values_shape) == 1:
+        padded_out = padded_out.squeeze(-1)
+    if inp._ragged_idx > 1:
+        # undo the earlier transpose; dims are shifted by one because the
+        # padded output has a leading batch dim
+        padded_out = padded_out.transpose(inp._ragged_idx, 1)
+
+    return padded_out
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_from_padded_tensor.default,
+    "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?",
+)
+def _nested_from_padded_tensor_default(func, *args, **kwargs):
+    """Build a jagged NestedTensor from a padded dense tensor and ``offsets``.
+
+    ``padded`` is reshaped to 3D with the ragged dim at position 1, converted
+    with ``_padded_dense_to_jagged_forward``, and reshaped back. Any
+    non-``None`` ``min_seqlen`` / ``max_seqlen`` is stored in the result's
+    metadata cache.
+    """
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    padded, offsets = new_kwargs["padded"], new_kwargs["offsets"]
+    ragged_idx = new_kwargs.get("ragged_idx", 1)
+
+    # only 3D padded with ragged packed dim=0 is supported by the underlying FBGEMM
+    # kernel so do shape gymnastics
+    if ragged_idx > 1:
+        # bring the ragged dim next to the batch dim
+        padded = padded.transpose(ragged_idx, 1)
+    # remember the pre-flatten shape so it can be restored after conversion
+    padded_ragged_dim1_shape = padded.shape
+    if padded.dim() > 3:
+        padded = padded.flatten(start_dim=2)
+    elif padded.dim() < 3:
+        padded = padded.unsqueeze(-1)
+
+    # NB: The CUDA kernel for padded dense -> jagged conversion does not support
+    # integer / bool types; work around this by casting to half.
+    # NOTE(review): only the bool case is actually cast below.
+    is_bool = padded.dtype is torch.bool
+    if is_bool and padded.is_cuda:
+        padded = padded.to(torch.half)
+    values = torch.ops.aten._padded_dense_to_jagged_forward(
+        padded, [offsets], new_kwargs["sum_S"]
+    )
+    if is_bool and values.is_cuda:
+        values = values.to(torch.bool)
+
+    # shape gymnastics part 2
+    if len(padded_ragged_dim1_shape) > 3:
+        values = values.unflatten(-1, padded_ragged_dim1_shape[2:])
+    elif len(padded_ragged_dim1_shape) < 3:
+        values = values.squeeze(-1)
+    if ragged_idx > 1:
+        # undo the earlier transpose; dims are shifted by one because the
+        # values buffer has no batch dim
+        values = values.transpose(ragged_idx - 1, 0)
+
+    min_seqlen = new_kwargs["min_seqlen"]
+    max_seqlen = new_kwargs["max_seqlen"]
+    # only cache the sequence-length bounds that were actually supplied
+    metadata_cache = {}
+    if min_seqlen is not None:
+        metadata_cache["min_seqlen"] = min_seqlen
+    if max_seqlen is not None:
+        metadata_cache["max_seqlen"] = max_seqlen
+
+    return NestedTensor(
+        values,
+        offsets,
+        _ragged_idx=ragged_idx,
+        _metadata_cache=metadata_cache,
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_view_from_jagged.default,
+    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
+)
+def _nested_view_from_jagged_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    values, offsets, lengths = (
+        new_kwargs["input"],
+        new_kwargs["offsets"],
+        new_kwargs["lengths"],
+    )
+    ragged_idx = new_kwargs["ragged_idx"]
+    min_seqlen = new_kwargs["min_seqlen"]
+    max_seqlen = new_kwargs["max_seqlen"]
+    metadata_cache = {}
+    if min_seqlen is not None:
+        metadata_cache["min_seqlen"] = min_seqlen
+    if max_seqlen is not None:
+        metadata_cache["max_seqlen"] = max_seqlen
+
+    return NestedTensor(
+        values,
+        offsets,
+        lengths=lengths,
+        _ragged_idx=ragged_idx,
+        _metadata_cache=metadata_cache,
+    )
+
+
+@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
+def _nested_get_offsets(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._offsets
+
+
+@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
+def _nested_get_lengths(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._lengths
+
+
+@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
+def _nested_get_ragged_idx(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._ragged_idx
+
+
+@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
+def _nested_get_min_seqlen(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._metadata_cache.get("min_seqlen", None)
+
+
+@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
+def _nested_get_max_seqlen(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._metadata_cache.get("max_seqlen", None)
+
+
+# If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0
+@register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
+def masked_select_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    mask = new_kwargs.pop("mask")
+
+    if inp.ndim > 2:
+        raise RuntimeError("masked_select only support 2-D selections currently")
+    elif inp.shape != mask.shape:
+        raise RuntimeError(
+            f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}"
+        )
+    res_values = inp._values.masked_select(mask.values())
+    mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0))  # type: ignore[arg-type]
+
+    args = extract_kwargs(inp)
+    args["offsets"] = mask_cumsum[inp._offsets]
+    return NestedTensor(
+        values=res_values,
+        **args,
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_select_backward.default,
+    "grad_output: t, self: jt_all, dim: any, index: any",
+)
+def _nested_select_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    grad_output = new_kwargs.pop("grad_output")
+
+    grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
+    grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
+
+    return grad_input
+
+
+@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
+def record_stream_default(func, *args, **kwargs) -> None:
+    inp = args[0]
+    stream = args[1]
+    # ensure all components live until stream computation completes
+    func(inp._values, stream)
+    func(inp._offsets, stream)
+    if inp._lengths is not None:
+        func(inp._lengths, stream)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.new_empty.default,
+        torch.ops.aten.new_zeros.default,
+        torch.ops.aten.new_ones.default,
+    ],
+    "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
+)
+def new_empty_default(func, *args, **kwargs):
+    """Handle ``new_empty`` / ``new_zeros`` / ``new_ones`` on a jagged input.
+
+    Only an empty ``size`` is supported; in that case the op is forwarded to
+    the values buffer of the input with the remaining keyword arguments.
+
+    Raises:
+        RuntimeError: if ``size`` is not empty.
+    """
+    _, call_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    nt = call_kwargs.pop("input")
+
+    if len(call_kwargs["size"]) != 0:
+        raise RuntimeError("new_empty() not supported for NJT with shape != ()")
+
+    return func(nt._values, **call_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.elu_backward.default,
+        torch.ops.aten.hardshrink_backward.default,
+        torch.ops.aten.hardsigmoid_backward.default,
+        torch.ops.aten.hardtanh_backward.default,
+        torch.ops.aten.softplus_backward.default,
+        torch.ops.aten.softshrink_backward.default,
+    ],
+    "self: jt_all, ...",
+)
+def activation_backward(func, *args, **kwargs):
+    """Run an activation backward op on the values buffers of its NJT args.
+
+    Every ``NestedTensor`` positional argument is replaced by its ``_values``;
+    other arguments are passed through unchanged. The dense result is
+    re-wrapped using the metadata of the first ``NestedTensor`` argument.
+    """
+    # The first NJT positional argument is expected to be grad_output.
+    first_nt = next(a for a in args if isinstance(a, NestedTensor))
+    dense_args = [a._values if isinstance(a, NestedTensor) else a for a in args]
+    dense_result = func(*dense_args, **kwargs)
+    return NestedTensor(dense_result, **extract_kwargs(first_nt))
+
+
+@register_jagged_func(torch.ops.aten.fill.Scalar, "self: jt_all, value: any")
+def fill_Scalar(func, *args, **kwargs):
+    """Out-of-place ``fill``: apply the op to the values buffer and re-wrap.
+
+    The returned ``NestedTensor`` carries the same metadata (as produced by
+    ``extract_kwargs``) as the input.
+    """
+    _, call_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    nt = call_kwargs.pop("input")
+    filled_values = func(nt._values, **call_kwargs)
+
+    return NestedTensor(filled_values, **extract_kwargs(nt))
+
+
+@register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
+def fill__Scalar(func, *args, **kwargs):
+    """In-place ``fill_``: apply the op to the values buffer of the input.
+
+    Returns:
+        The same jagged tensor object that was passed in.
+    """
+    _, call_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    nt = call_kwargs.pop("input")
+
+    # The result of the call is discarded; the input itself is handed back.
+    func(nt._values, **call_kwargs)
+    return nt
+
+
+@register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
+def frexp_Tensor(func, *args, **kwargs):
+    """Apply ``frexp`` to the values buffer and wrap both outputs.
+
+    Returns:
+        A ``(mantissa, exponent)`` pair of ``NestedTensor`` objects that both
+        use the metadata extracted from the input.
+    """
+    _, call_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    nt = call_kwargs.pop("input")
+    nt_metadata = extract_kwargs(nt)
+
+    mantissa, exponent = func(nt._values)
+    wrapped_mantissa = NestedTensor(mantissa, **nt_metadata)
+    wrapped_exponent = NestedTensor(exponent, **nt_metadata)
+    return wrapped_mantissa, wrapped_exponent
+
+
+@register_jagged_func(
+    torch.ops.aten.matmul_backward.default,
+    "grad: any, self: any, other: any, mask: any",
+)
+def matmul_backward_default(func, *args, **kwargs):
+    """Compute the gradients of ``matmul`` for its two operands.
+
+    ``mask[0]`` / ``mask[1]`` select whether the gradient for the first /
+    second operand is computed; an unselected gradient is ``None``.
+
+    Returns:
+        ``(grad_self, grad_other)``, or ``(None, None)`` when ``grad`` is None.
+    """
+    _, call_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = call_kwargs.pop("grad")
+    lhs = call_kwargs.pop("input")
+    rhs = call_kwargs.pop("other")
+    wants_grad = call_kwargs.pop("mask")
+
+    if grad is None:
+        return (None, None)
+
+    # d(lhs @ rhs)/d(lhs) = grad @ rhs^T ; d(lhs @ rhs)/d(rhs) = lhs^T @ grad
+    grad_lhs = torch.matmul(grad, rhs.transpose(-1, -2)) if wants_grad[0] else None
+    grad_rhs = torch.matmul(lhs.transpose(-1, -2), grad) if wants_grad[1] else None
+
+    return (grad_lhs, grad_rhs)
+
+
+# Make the dummy available on the C++ side.
+@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
+def _nested_get_jagged_dummy(func, *args, **kwargs):
+    """Return the result of ``_nt_view_dummy()``; all arguments are ignored."""
+    # Imported at call time rather than at module import time.
+    from torch.nested._internal import nested_tensor
+
+    return nested_tensor._nt_view_dummy()
+
+
+# Register the Python function above as the "_nested_get_jagged_dummy"
+# implementation in the "aten" library for the CPU, CUDA and Meta keys.
+# NOTE(review): `_scoped_library` is used as a context manager here; confirm
+# whether the registrations are expected to outlive this `with` block.
+with torch.library._scoped_library("aten", "IMPL") as aten:
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py
new file mode 100644
index 0000000000000000000000000000000000000000..328702ede37462cf880503e575b6d722b7ba4a40
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py
@@ -0,0 +1,933 @@
+# mypy: allow-untyped-defs
+import logging
+
+import torch
+import torch.nn
+import torch.nn.functional as F
+from torch.backends.cuda import (
+    can_use_cudnn_attention,
+    can_use_efficient_attention,
+    can_use_flash_attention,
+    cudnn_sdp_enabled,
+    flash_sdp_enabled,
+    math_sdp_enabled,
+    mem_efficient_sdp_enabled,
+    SDPAParams,
+)
+from torch.nn.attention import SDPBackend
+
+from .nested_tensor import NestedTensor
+
+
+log = logging.getLogger(__name__)
+
+
+def _validate_sdpa_input(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+) -> None:
+    """Validate the arguments of jagged scaled dot product attention.
+
+    The checks run in this order and the first failure raises:
+
+    1. ``query``, ``key`` and ``value`` are all ``NestedTensor`` instances.
+    2. They share one dtype.
+    3. They share one device.
+    4. Each has at least 3 dimensions.
+    5. They are ragged on the same dimension (``_ragged_idx``).
+    6. ``attn_mask`` is ``None`` (any mask is rejected).
+
+    ``dropout_p``, ``is_causal`` and ``scale`` are accepted for signature
+    compatibility but are not inspected.
+
+    Raises:
+        ValueError: if any of the checks above fails.
+    """
+    if (
+        not isinstance(query, NestedTensor)
+        or not isinstance(key, NestedTensor)
+        or not isinstance(value, NestedTensor)
+    ):
+        raise ValueError(
+            f"Expected query, key, and value to be nested tensors, "
+            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
+            f"and value.is_nested: {value.is_nested} instead."
+        )
+    if query.dtype != key.dtype or query.dtype != value.dtype:
+        raise ValueError(
+            f"Expected query, key, and value to have the same dtype, "
+            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
+            f"and value.dtype: {value.dtype} instead."
+        )
+    if query.device != key.device or query.device != value.device:
+        raise ValueError(
+            f"Expected query, key, and value to have the same device type, "
+            f"but got query.device: {query.device}, key.device: {key.device}, "
+            f"and value.device: {value.device} instead."
+        )
+    if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
+        raise ValueError(
+            f"Expected query, key, and value to all be  at least 3 dimensional, but got query.dim: "
+            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
+        )
+    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
+        raise ValueError(
+            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
+            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
+        )
+    if attn_mask is not None:
+        # TODO: Figure out whether masks are actually supported for this layout or not.
+        # Once they are, validate here that attn_mask.dtype is torch.bool or
+        # matches query.dtype before accepting the mask.
+        raise ValueError("Masks are not yet supported!")
+
+
+def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
+    """Return whether query, key and value agree on ``size(0)``.
+
+    ``debug`` is accepted so this can be used like the other constraint
+    checks, but it is not used here.
+    """
+    # Expected to run after check_tensor_shapes, which ensures the inputs are
+    # all 4 dimensional so the size() calls below won't error.
+    q_bsz, k_bsz, v_bsz = (
+        params.query.size(0),
+        params.key.size(0),
+        params.value.size(0),
+    )
+
+    # The num_heads logic for nested input lives in
+    # check_for_seq_len_0_nested_tensor, which makes sure num_heads is not
+    # ragged.
+    return q_bsz == k_bsz and q_bsz == v_bsz
+
+
+def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
+    """Check the last-dimension constraint for the flash attention path.
+
+    Returns True when query, key and value share the same last-dimension
+    size, that size is a multiple of 8 and it does not exceed 256. On
+    failure a warning is logged if ``debug`` is set.
+    """
+    max_size = 256
+    q_last = params.query.size(-1)
+    k_last = params.key.size(-1)
+    v_last = params.value.size(-1)
+
+    sizes_match = q_last == k_last and q_last == v_last
+    if sizes_match and (q_last % 8 == 0) and (q_last <= max_size):
+        return True
+
+    if debug:
+        log.warning(
+            "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
+            "last dimension and to be a multiple of 8 and less than or equal to 256. "
+            "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
+            q_last,
+            k_last,
+            v_last,
+        )
+    return False
+
+
+def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool:
+    """Check the last-dimension constraint for the cuDNN attention path.
+
+    Returns True when query, key and value share the same last-dimension
+    size, that size is a multiple of 8 and it does not exceed 128. On
+    failure a warning is logged if ``debug`` is set.
+    """
+    max_size = 128
+    q_last = params.query.size(-1)
+    k_last = params.key.size(-1)
+    v_last = params.value.size(-1)
+
+    sizes_match = q_last == k_last and q_last == v_last
+    if sizes_match and (q_last % 8 == 0) and (q_last <= max_size):
+        return True
+
+    if debug:
+        log.warning(
+            "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same "
+            "last dimension and to be a multiple of 8 and less than or equal to 128. "
+            "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
+            q_last,
+            k_last,
+            v_last,
+        )
+    return False
+
+
+def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+    param: torch.Tensor, param_name: str, debug=False
+) -> bool:
+    """Check one jagged input against the fused-kernel restrictions.
+
+    Returns False when ``param`` is ragged on dim 1 or when its minimum
+    sequence length is 0; otherwise True. When ``debug`` is set, the reason
+    for a failure is logged with ``param_name``.
+    """
+    assert isinstance(param, NestedTensor), "param should be a jagged NT"
+
+    if param._ragged_idx == 1:
+        # num_head_dims is ragged
+        reason = "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads."
+    elif param._get_min_seqlen() == 0:
+        # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
+        reason = "Fused kernels do not support seq_len == 0, %s has a seq len of 0."
+    else:
+        return True
+
+    if debug:
+        log.warning(reason, param_name)
+    return False
+
+
+def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
+    """Return whether three sizes are mutually broadcastable.
+
+    Each of ``q_size``, ``k_size`` and ``v_size`` must equal either the
+    largest of the three or 1. When that fails and ``debug`` is set, a
+    warning naming ``param_name`` is logged.
+    """
+    largest = max(q_size, k_size, v_size)
+
+    def _fits(size) -> bool:
+        return size == largest or size == 1
+
+    if _fits(q_size) and _fits(k_size) and _fits(v_size):
+        return True
+
+    if debug:
+        log.warning(
+            "Both fused kernels require query, key and value to have broadcastable %s, "
+            "got Query %s %d, Key %s %d, Value %s %d instead.",
+            param_name,
+            param_name,
+            q_size,
+            param_name,
+            k_size,
+            param_name,
+            v_size,
+        )
+    return False
+
+
+def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
+    """Check q/k/v for the seq-len / num-heads fused-kernel restrictions.
+
+    Each nested input among query, key and value (in that order) must pass
+    ``_check_for_seq_len_0_and_consistent_head_dim_nested_helper``; the first
+    failure returns False. After that, ``size(1)`` of the three inputs is
+    compared: equal values pass; unequal values fail when any input requires
+    grad, and otherwise defer to ``_try_broadcast_param_size``.
+    """
+    # When this function is called we are assured that the nt is dim==4
+    for name in ("query", "key", "value"):
+        tensor = getattr(params, name)
+        # Dense inputs are treated as safe; stop at the first unsafe one.
+        if tensor.is_nested and not (
+            _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+                tensor, name, debug
+            )
+        ):
+            return False
+
+    # We now know none of the inputs have ragged num_heads, so we can safely
+    # access .size(1)
+    q_heads = params.query.size(1)
+    k_heads = params.key.size(1)
+    v_heads = params.value.size(1)
+
+    if q_heads == k_heads and q_heads == v_heads:
+        return True
+
+    any_requires_grad = (
+        params.query.requires_grad
+        or params.key.requires_grad
+        or params.value.requires_grad
+    )
+    if any_requires_grad:
+        if debug:
+            log.warning(
+                "Both fused kernels do not support training with broadcasted NT inputs."
+            )
+        return False
+
+    return _try_broadcast_param_size(q_heads, k_heads, v_heads, "num heads", debug)
+
+
+def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    """Return True when every jagged-specific flash attention check passes.
+
+    The checks run in order (batch size, head dim size, seq len / num heads)
+    and evaluation stops at the first one that fails.
+    """
+    checks = (
+        _check_batch_size_nested,
+        _check_head_dim_size_flash_nested,
+        _check_for_seq_len_0_nested,
+    )
+    return all(check(params, debug) for check in checks)
+
+
+def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    """Return True when every jagged-specific efficient attention check passes.
+
+    The checks run in order (batch size, seq len / num heads) and evaluation
+    stops at the first one that fails.
+    """
+    checks = (
+        _check_batch_size_nested,
+        _check_for_seq_len_0_nested,
+    )
+    return all(check(params, debug) for check in checks)
+
+
+def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    """Return whether the math backend can be used for these inputs.
+
+    Requires query, key and value to each be contiguous after
+    ``transpose(1, 2)`` and ``params.is_causal`` to be falsy. The reason for
+    a rejection is logged when ``debug`` is set.
+    """
+    qkv = (params.query, params.key, params.value)
+    if any(not t.transpose(1, 2).is_contiguous() for t in qkv):
+        if debug:
+            log.warning(
+                "If inputs are nested tensors they must be contiguous after transposing."
+            )
+        return False
+
+    if not params.is_causal:
+        return True
+
+    if debug:
+        log.warning(
+            "Nested tensors for query / key are not supported when is_causal=True."
+        )
+    return False
+
+
+def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
+    """Pick the SDPA backend to use for jagged inputs.
+
+    Backends are tried in the order flash, memory-efficient, math, cuDNN and
+    the first one whose checks pass is returned.
+
+    Returns:
+        The chosen ``SDPBackend``, or ``SDPBackend.ERROR`` when all four
+        backends are disabled or none of them qualifies. In the latter case
+        the reasons are logged by re-running the checks with ``debug=True``.
+    """
+    any_backend_enabled = (
+        flash_sdp_enabled()
+        or mem_efficient_sdp_enabled()
+        or math_sdp_enabled()
+        or cudnn_sdp_enabled()
+    )
+    if not any_backend_enabled:
+        return SDPBackend.ERROR
+
+    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
+
+    # (backend, predicate) pairs listed in priority order.
+    candidates = (
+        (
+            SDPBackend.FLASH_ATTENTION,
+            lambda: can_use_flash_attention(params)
+            and _can_use_flash_sdpa_jagged(params),
+        ),
+        (
+            SDPBackend.EFFICIENT_ATTENTION,
+            lambda: can_use_efficient_attention(params)
+            and _can_use_efficient_sdpa_jagged(params),
+        ),
+        (
+            SDPBackend.MATH,
+            lambda: math_sdp_enabled() and _can_use_math_sdpa_jagged(params),
+        ),
+        (
+            SDPBackend.CUDNN_ATTENTION,
+            lambda: can_use_cudnn_attention(params),
+        ),
+    )
+    for backend, is_usable in candidates:
+        if is_usable():
+            return backend
+
+    # Nothing qualified: rerun every check in debug mode to log the reasons.
+    log.warning("Memory efficient kernel not used because:")
+    can_use_efficient_attention(params, debug=True)
+    _can_use_efficient_sdpa_jagged(params, debug=True)
+    log.warning("Flash attention kernel not used because:")
+    can_use_flash_attention(params, debug=True)
+    _can_use_flash_sdpa_jagged(params, debug=True)
+    log.warning("Math attention kernel not used because:")
+    _can_use_math_sdpa_jagged(params, debug=True)
+    log.warning("cuDNN attention kernel not used because:")
+    can_use_cudnn_attention(params, debug=True)
+    return SDPBackend.ERROR
+
+
+def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
+    """Compute the sequence-length metadata for one jagged input.
+
+    This metadata is needed by the flash-attention and efficient_attention
+    kernels.
+
+    Returns:
+        A ``(cumulative_seqlen, max_seqlen, n_elem)`` tuple.
+        ``cumulative_seqlen`` is an int32 tensor on ``qkv.device``: the
+        offsets when ``qkv`` has no lengths, otherwise the cumulative sum of
+        the lengths. ``n_elem`` is ``values().shape[0]`` in the first case
+        and the last entry of ``cumulative_seqlen`` in the second.
+
+    Raises:
+        ValueError: if ``qkv`` is not a ``NestedTensor``.
+    """
+    if not isinstance(qkv, NestedTensor):
+        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
+
+    lengths = qkv.lengths()
+    if lengths is not None:
+        # TODO: Explore performance impact of copying
+        cumulative_seqlen = lengths.cumsum(0).to(dtype=torch.int32, device=qkv.device)
+        max_seqlen = qkv._get_max_seqlen()
+        # TODO: Explore performance impact when compiling
+        n_elem = int(cumulative_seqlen[-1].item())
+    else:
+        # TODO: Explore performance impact of copying
+        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
+        max_seqlen = qkv._get_max_seqlen()
+        n_elem = qkv.values().shape[0]
+
+    return cumulative_seqlen, max_seqlen, n_elem
+
+
+def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor) -> bool:
+    """Stride check used to decide whether ``contiguous()`` can be skipped.
+
+    This is used before handing a nested tensor to the flash-attention and
+    efficient_attention kernels, after it has been transposed to
+    [bsz, {seq_len}, num_heads, dim].
+
+    Returns True when the nested tensor has at most one component, or when
+    its strides from index 1 onward are strictly decreasing. Returns False
+    as soon as a stride is not smaller than the one before it.
+    """
+    assert isinstance(tensor, NestedTensor)
+    component_offsets = tensor.offsets()
+    strides = tensor._strides
+
+    num_components = component_offsets.size(0) - 1
+    if num_components > 1:
+        # The strides after the batch dim must be in strictly descending order.
+        prev_stride = strides[1]
+        for stride in strides[2:]:
+            if prev_stride <= stride:
+                # A later stride is at least as large as an earlier one.
+                return False
+            prev_stride = stride
+
+    return True
+
+
+def _view_as_dense(
+    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
+) -> torch.Tensor:
+    """Return a dense tensor for ``tensor``.
+
+    A nested input yields its ``values()``; a dense input is viewed as
+    ``(Nnz, num_heads, head_dim)``. The shape arguments are only used for
+    the dense case.
+    """
+    if not tensor.is_nested:
+        return tensor.view(Nnz, num_heads, head_dim)
+    return tensor.values()
+
+
+# TODO: Next iteration should add test cases and check it works
+# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
+#     # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
+#     # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+#     q_batch_size = query.size(0)
+#     k_batch_size = key.size(0)
+#     v_batch_size = value.size(0)
+
+#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
+
+#     q_num_heads = query.size(1)
+#     k_num_heads = key.size(1)
+#     v_num_heads = value.size(1)
+
+#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
+
+#     head_dim_qk = query.size(3)
+#     head_dim_v = value.size(3)
+
+#     q_t = query.transpose(1, 2)
+#     k_t = key.transpose(1, 2)
+#     v_t = value.transpose(1, 2)
+
+#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
+#     # output_batch_size/num_heads then they are 1
+#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
+#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
+#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size
+
+#     # If {*}_batch_size_needs_broadcast, then
+#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
+#     #     this is because needs_broadcast indicates that the batch_size is 1
+#     #     and hence there is only 1 value for seq_len
+#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
+#     # ..., output_batch_size * {*}_t.size(1)]
+#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
+
+#     if q_batch_size_needs_broadcast or not q_t.is_nested:
+#         max_seqlen_batch_q = q_t.size(1)
+#         cumulative_sequence_length_q = torch.arange(
+#             0,
+#             (output_batch_size + 1) * max_seqlen_batch_q,
+#             max_seqlen_batch_q,
+#             device=q_t.device,
+#             dtype=torch.int32,
+#         )
+#         Nnz_q = output_batch_size * max_seqlen_batch_q
+#     else:
+#         (
+#             cumulative_sequence_length_q,
+#             max_seqlen_batch_q,
+#             Nnz_q,
+#         ) = _cumulative_and_max_seq_len_nnz(q_t)
+
+#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
+#         assert k_t.size(1) == v_t.size(1)
+#         max_seqlen_batch_kv = k_t.size(1)
+#         cumulative_sequence_length_kv = torch.arange(
+#             0,
+#             (output_batch_size + 1) * max_seqlen_batch_kv,
+#             max_seqlen_batch_kv,
+#             device=k_t.device,
+#             dtype=torch.int32,
+#         )
+#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
+#     else:
+#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
+#             _cumulative_and_max_seq_len_nnz(v_t)
+#             if k_batch_size_needs_broadcast
+#             else _cumulative_and_max_seq_len_nnz(k_t)
+#         )
+
+#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
+#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
+#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads
+
+#     if not q_t.is_nested:
+#         query_buffer_reshaped = q_t.expand(
+#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
+#         )
+#         query_buffer_reshaped = query_buffer_reshaped.reshape(
+#             Nnz_q, output_num_heads, head_dim_qk
+#         )
+#     else:
+#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
+#             q_t = q_t.contiguous()
+#         # If we are broadcasting then Nnz_q will be the output_batch_size since
+#         # seq_len is 1
+#         effective_batch_size_q = (
+#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
+#         )
+#         query_buffer_reshaped = _view_as_dense(
+#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
+#         )
+
+#     # If the physical layout of the NestedTensor's storage
+#     # is not: batch, {seq_len}, num_heads, head_dim then we need
+#     # to call contiguous
+#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
+#         k_t = k_t.contiguous()
+#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
+#         v_t = v_t.contiguous()
+
+#     effective_batch_size_k = (
+#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
+#     )
+#     key_buffer_reshaped = _view_as_dense(
+#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
+#     )
+
+#     effective_batch_size_v = (
+#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
+#     )
+#     value_buffer_reshaped = _view_as_dense(
+#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
+#     )
+
+#     if not q_batch_size_needs_broadcast:
+#         output_shape = q_t._size
+#         if head_dim_v != head_dim_qk:
+#             output_shape[-1] = head_dim_v
+#         if q_num_heads_needs_broadcast:
+#             output_shape[1] = output_num_heads
+#     else:
+#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
+#         output_shape[0] = q_t.size(1)
+#         output_shape[1] = output_num_heads
+#         output_shape[2] = head_dim_v
+
+#     return (
+#         query_buffer_reshaped,
+#         key_buffer_reshaped,
+#         value_buffer_reshaped,
+#         cumulative_sequence_length_q,
+#         cumulative_sequence_length_kv,
+#         max_seqlen_batch_q,
+#         max_seqlen_batch_kv,
+#         output_shape,
+#     )
+
+
+def _sdpa_nested_preprocessing(query, key, value):
+    """Turn jagged q/k/v into dense buffers plus sequence-length metadata.
+
+    Expected layouts:
+        Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
+        Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+        Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+
+    Returns:
+        An 8-tuple: the dense query, key and value buffers, the cumulative
+        sequence lengths for q and for kv, the max sequence lengths for q and
+        for kv, and a dict with the ``offsets`` / ``lengths`` / ``max_seqlen``
+        / ``min_seqlen`` of the transposed query.
+
+    Raises:
+        RuntimeError: if the batch sizes or the head counts of the three
+            inputs differ (the broadcasting path is not implemented).
+    """
+    q_bsz, k_bsz, v_bsz = query.size(0), key.size(0), value.size(0)
+    q_heads, k_heads, v_heads = query.size(1), key.size(1), value.size(1)
+
+    batch_matches = q_bsz == k_bsz and q_bsz == v_bsz
+    if not batch_matches or not (q_heads == k_heads and k_heads == v_heads):
+        raise RuntimeError(
+            "This path is currently not implemented for jagged layout NT."
+        )
+        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
+
+    num_heads = query.size(1)
+    head_dim_qk = query.size(3)
+    head_dim_v = value.size(3)
+
+    # Swap the heads and sequence dims of every input.
+    q_t, k_t, v_t = (t.transpose(1, 2) for t in (query, key, value))
+
+    cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q = (
+        _cumulative_and_max_seq_len_nnz(q_t)
+    )
+    cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
+        _cumulative_and_max_seq_len_nnz(k_t)
+    )
+
+    # [TODO] K and V have to have the same Nnz, should probably torch_check
+    # assume in order to not iterate over v
+
+    def _with_usable_layout(nt):
+        # If the physical layout of the NestedTensor's storage is not
+        # batch, {seq_len}, num_heads, head_dim then contiguous() is needed.
+        if not nt.is_contiguous() and not _is_safe_to_get_storage_as_tensor(nt):
+            return nt.contiguous()
+        return nt
+
+    q_t = _with_usable_layout(q_t)
+    k_t = _with_usable_layout(k_t)
+    v_t = _with_usable_layout(v_t)
+
+    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
+    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
+    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
+
+    # Metadata of the (transposed) query, handed back alongside the buffers.
+    output_nt_info = {
+        "offsets": q_t.offsets(),
+        "lengths": q_t.lengths(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
+    }
+
+    return (
+        query_buffer_reshaped,
+        key_buffer_reshaped,
+        value_buffer_reshaped,
+        cumulative_sequence_length_q,
+        cumulative_sequence_length_kv,
+        max_seqlen_batch_q,
+        max_seqlen_batch_kv,
+        output_nt_info,
+    )
+
+
+def _pad_last_dim(
+    tensor: torch.Tensor, alignment_size: int, slice: bool
+) -> torch.Tensor:
+    """Pad the last dimension up to a multiple of ``alignment_size``.
+
+    FlashAttentionV2 requires the head dimension to be a multiple of 8. That
+    padding used to happen inside the kernel, which could make the kernel
+    alias query, key and value, so it is done here in the composite region
+    instead.
+
+    Args:
+        tensor: input whose last dimension may need padding.
+        alignment_size: the multiple the last dimension is padded up to.
+        slice: when True, the padded result is sliced back to the original
+            last-dimension size before being returned.
+
+    Returns:
+        ``tensor`` itself when it is already aligned, otherwise the padded
+        (and optionally sliced) tensor.
+    """
+    last_dim_size = tensor.size(-1)
+    remainder = last_dim_size % alignment_size
+    if remainder == 0:
+        return tensor
+
+    padded = torch.nn.functional.pad(tensor, [0, alignment_size - remainder])
+    return padded[..., 0:last_dim_size] if slice else padded
+
+
+# TODO: coalesce with torch/nn/utils/attention.py
+def _calculate_scale(query, scale):
+    """Return ``scale`` if given, else ``sqrt(1 / query.size(-1))``."""
+    if scale is not None:
+        return scale
+    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
+    return torch.sym_sqrt(1.0 / query.size(-1))
+
+
+def _post_process_flash_output(out: torch.Tensor, og_size):
+    """Slice a dense output's last dimension back to ``og_size``.
+
+    Nested outputs, and dense outputs whose last dimension already equals
+    ``og_size``, are returned unchanged.
+    """
+    needs_trim = (not out.is_nested) and out.size(-1) != og_size
+    return out[..., 0:og_size] if needs_trim else out
+
+
+def _is_computing_meta_flops(x):
+    """Detect the meta-tensor + dispatch-based flop counter scenario.
+
+    Returns True only when not scripting, ``x`` lives on the meta device and
+    a ``_FlopCounterMode`` is present on the current dispatch mode stack.
+    Callers use this to special-case that scenario.
+    """
+    if torch.jit.is_scripting() or x.device.type != "meta":
+        return False
+
+    mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+    return any(
+        type(mode) is torch.utils.flop_counter._FlopCounterMode
+        for mode in mode_stack
+    )
+
+
+def _autocast(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
+    """
+    [Autocasting SDPA for NJT]
+
+    Manually applies autocast dtype conversion to the SDPA inputs.
+
+    Regular autocasting does not work for NJT+SDPA at the moment:
+    * NJT intercepts the __torch_function__ call for scaled_dot_product_attention
+      before any aten op or dispatcher logic runs, and the torch_function logic
+      then calls into efficient attention or flash attention. Autocasting on the
+      scaled_dot_product_attention op therefore never triggers, because that
+      aten op is never seen.
+    * Putting autocasting on `_flash_attention_forward` would make it run, but
+      the kernel selection in the torch_function handling (ie.
+      jagged_scaled_dot_product_attention) would then be wrong: selection runs
+      before autocasting and picks a kernel for the un-autocasted dtypes, while
+      the actual attention computation happens in a different dtype.
+
+    One alternative would be to make the SDPA+NJT backend selection
+    autocast-aware and let autocasting do the real conversions for flash /
+    efficient attention. Doing the conversion by hand before backend selection
+    instead guarantees that the dtypes used for selection and the dtypes used
+    for the computation cannot diverge.
+
+    Returns the four inputs unchanged when autocast is disabled for the device
+    type of ``query`` or when computing meta flops. Otherwise each non-None
+    input is cast to the autocast dtype, except inputs that are not floating
+    point, already have that dtype, or are float64.
+    """
+    device_type = query.device.type
+    # meta device is not supported by autocast, so break early for it
+    if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
+        return query, key, value, attn_mask
+
+    def _to_autocast_dtype(t):
+        if t is None:
+            return t
+        autocast_dtype = torch.get_autocast_dtype(device_type)
+        leave_as_is = (
+            (not t.dtype.is_floating_point)
+            or t.dtype == autocast_dtype
+            or t.dtype == torch.float64
+        )
+        return t if leave_as_is else t.to(autocast_dtype)
+
+    return (
+        _to_autocast_dtype(query),
+        _to_autocast_dtype(key),
+        _to_autocast_dtype(value),
+        _to_autocast_dtype(attn_mask),
+    )
+
+
+def jagged_scaled_dot_product_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,  # NOTE(review): only consulted by _select_sdp_backend below
+):
+    query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
+    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
+    # Narrow the operand types for mypy: everything below relies on NestedTensor inputs.
+    assert (
+        isinstance(query, NestedTensor)
+        and isinstance(key, NestedTensor)
+        and isinstance(value, NestedTensor)
+    )
+    from torch.nested._internal.nested_tensor import (
+        nested_view_from_values_offsets_lengths,
+    )
+
+    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
+    # second batch dim instead). For this case, we can just send the dense buffers through
+    # vanilla SDPA.
+    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
+        output = F.scaled_dot_product_attention(
+            query.values(),
+            key.values(),
+            value.values(),
+            attn_mask=(
+                attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
+            ),
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+        )
+        return nested_view_from_values_offsets_lengths(
+            output,
+            query.offsets(),
+            query.lengths(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
+
+    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
+
+    backend_choice = _select_sdp_backend(
+        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
+    )
+
+    if _is_computing_meta_flops(query):
+        # Backend choice will probably not be correct if we have a meta device,
+        # because backend choice is device-aware. In this case, we mostly just
+        # want to avoid using math backend (which does a .item() call).
+        # Arbitrarily choose flash attention.
+        backend_choice = SDPBackend.FLASH_ATTENTION
+
+    if backend_choice == SDPBackend.FLASH_ATTENTION:
+        og_size = query.size(-1)
+        query_padded = _pad_last_dim(query, 8, False)
+        key_padded = _pad_last_dim(key, 8, False)
+        value_padded = _pad_last_dim(value, 8, False)
+        # The scale must be derived from the original (unpadded) query, not the padded one
+        og_scale = _calculate_scale(query, scale)
+        (
+            query_buffer_reshaped,
+            key_buffer_reshaped,
+            value_buffer_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
+        (
+            attention,
+            _logsumexp,
+            _philox_seed,
+            _philox_offset,
+            _debug_attn_mask,
+        ) = torch.ops.aten._flash_attention_forward(
+            query_buffer_reshaped,
+            key_buffer_reshaped,
+            value_buffer_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            dropout_p,
+            is_causal,
+            False,
+            scale=og_scale,
+        )
+        # Reshape output to convert nnz to batch_size and seq_len
+        attention = nested_view_from_values_offsets_lengths(
+            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
+            **output_nt_info,
+        ).transpose(1, 2)
+        return _post_process_flash_output(attention, og_size)
+    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
+        (
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query, key, value)
+        (
+            attention,
+            log_sumexp,
+            seed,
+            offset,
+            max_seqlen_q,
+            max_seqlen_batch_kv,  # NOTE(review): rebinds the local above; not read again in this branch
+        ) = torch.ops.aten._efficient_attention_forward(
+            query_reshaped.unsqueeze(0),
+            key_reshaped.unsqueeze(0),
+            value_reshaped.unsqueeze(0),
+            None,  # presumably the bias argument; attn_mask is not passed on this path -- TODO confirm
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            dropout_p,
+            int(is_causal),
+            compute_logsumexp,
+            scale=scale,
+        )
+        # Reshape output to convert nnz to batch_size and seq_len
+        return nested_view_from_values_offsets_lengths(
+            attention.squeeze(0),
+            **output_nt_info,
+        ).transpose(1, 2)
+    elif backend_choice == SDPBackend.CUDNN_ATTENTION:
+        (
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query, key, value)
+        (
+            attention,
+            logsumexp,
+            cum_seqlen_q,
+            cum_seqlen_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+            seed,
+            offset,
+            _,
+        ) = torch.ops.aten._cudnn_attention_forward(
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            attn_mask,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            compute_logsumexp,
+            dropout_p,
+            is_causal,
+            False,
+            scale=scale,
+        )
+        return nested_view_from_values_offsets_lengths(
+            attention,
+            **output_nt_info,
+        ).transpose(1, 2)
+    elif backend_choice == SDPBackend.MATH:
+        # save the offsets and shape of the inputs, so we can reshape the final output
+        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
+        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
+        offsets = query.offsets()
+        q_lengths = query.lengths()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
+        d1 = query._size[1]
+        d2 = value._size[-1]
+
+        # convert a jagged layout Nested Tensor to a strided layout Nested Tensor,
+        # which is what the math implementation of SDPA is called with below
+        def get_strided_layout_nested_tensor(jagged_layout_nt):
+            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
+            transpose = torch.transpose(jagged_layout_nt, 1, 2)
+            tensor_list = transpose.values().split(list(lengths), dim=0)
+            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
+            strided_nt = strided_nt.transpose(1, 2).contiguous()
+            return strided_nt
+
+        query = get_strided_layout_nested_tensor(query)
+        key = get_strided_layout_nested_tensor(key)
+        value = get_strided_layout_nested_tensor(value)
+
+        attn_out = torch._scaled_dot_product_attention_math(
+            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
+        )[0]
+
+        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
+        attn_out = attn_out.transpose(1, 2).contiguous().values()
+        attn_out = attn_out.view(-1, d1, d2)
+        attn_out = nested_view_from_values_offsets_lengths(
+            attn_out,
+            offsets,
+            lengths=q_lengths,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        ).transpose(1, 2)
+
+        return attn_out
+    else:
+        raise RuntimeError(
+            "No viable backend for scaled_dot_product_attention was found."
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/attention/experimental/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/attention/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a6694bbe3990bacac6025e8c8bd4ab86e80d2e9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/attention/experimental/__init__.py
@@ -0,0 +1,2 @@
+# Experimental features are not mature yet and are subject to change.
+# We do not provide any BC/FC guarantees
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/attention/experimental/_paged_attention.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/attention/experimental/_paged_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bbbc2b78aa6ab54983965458d1901dc4e1a1bb1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/attention/experimental/_paged_attention.py
@@ -0,0 +1,354 @@
+# mypy: allow-untyped-defs
+"""
+This module implements Paged Attention on top of flex_attention.
+This module is experimental and subject to change.
+"""
+
+import torch
+from torch.nn.attention.flex_attention import (
+    _identity,
+    _mask_mod_signature,
+    _score_mod_signature,
+    BlockMask,
+    noop_mask,
+)
+
+
+__all__ = ["PagedAttention"]
+
+
+def _cdiv(x: int | float | torch.Tensor, multiple: int | float | torch.Tensor):
+    return (x + multiple - 1) // multiple  # ceiling division for integer x and multiple > 0
+
+
+class PagedAttention:
+    """
+    PagedAttention supports flex attention inference with a large batch size.
+    With PagedAttention, a batch of key/value tensors with varying kv length
+    is split into tensor blocks of fixed length and cached in a compact way.
+    Thus we can avoid redundant memory consumption due to varying kv length and
+    support a larger batch size.
+    """
+
+    def __init__(
+        self,
+        n_pages: int,
+        page_size: int,
+        max_batch_size: int,
+        device: str = "cuda",
+    ) -> None:
+        # number of pages
+        self.n_pages = n_pages
+
+        # number of tokens per page
+        self.page_size = page_size
+
+        # page table: [batch, logical_block_idx] -> physical_page_idx
+        self.page_table = -torch.ones(
+            (max_batch_size, self.n_pages), dtype=torch.int64, device=device
+        )
+
+        # capacity: batch_idx -> allocated sequence length
+        self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
+
+        # indices of empty pages that are available for allocation (popped from the end)
+        self.empty_pages = list(range(n_pages - 1, -1, -1))
+
+        # physical_to_logical: [batch, physical_page_idx] -> logical_page_idx (-1 if unmapped)
+        self.physical_to_logical = -torch.ones(
+            (max_batch_size, n_pages), dtype=torch.int64, device=device
+        )
+
+    def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None:
+        """
+        Requests the capacity of a given batch to be at least enough to
+        hold `seq_len` elements.
+
+        Args:
+            batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
+            seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.
+        """
+
+        if seq_len <= self.capacity[batch_idx]:
+            return
+
+        num_pages_to_allocate = _cdiv(
+            seq_len - self.capacity[batch_idx], self.page_size
+        )
+
+        assert len(self.empty_pages) >= num_pages_to_allocate, (
+            f"requested {num_pages_to_allocate.item()} pages "
+            f"but there are only {len(self.empty_pages)} empty pages"
+        )
+
+        start_page_idx = self.capacity[batch_idx] // self.page_size
+        end_page_idx = start_page_idx + num_pages_to_allocate
+
+        # find empty physical pages
+        allocated_pages = torch.tensor(
+            self.empty_pages[-num_pages_to_allocate:],
+            device=num_pages_to_allocate.device,
+        )
+        self.empty_pages = self.empty_pages[:-num_pages_to_allocate]
+
+        # update page table
+        self.page_table[
+            batch_idx,
+            start_page_idx:end_page_idx,
+        ] = allocated_pages
+
+        # update metadata
+        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
+            start_page_idx.item(),
+            end_page_idx.item(),
+            device=num_pages_to_allocate.device,
+        )
+        self.capacity[batch_idx] += num_pages_to_allocate * self.page_size
+
+    def erase(self, batch_idx: torch.Tensor) -> None:
+        """
+        Removes a single batch from paged attention.
+
+        Args:
+            batch_idx (Tensor): batch index to be removed; shape :math:`(1)`.
+        """
+
+        # find allocated pages
+        allocated_page_idx = self.page_table[batch_idx] != -1
+        allocated_pages = self.page_table[batch_idx][allocated_page_idx]
+
+        # clean metadata (index in one step: chained tensor indexing would write to a copy)
+        self.capacity[batch_idx] = 0
+        self.empty_pages += allocated_pages.tolist()
+        self.physical_to_logical[batch_idx, allocated_pages] = -1
+        self.page_table[batch_idx] = -1
+
+    def assign(
+        self,
+        batch_idx: torch.Tensor,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+    ) -> None:
+        """
+        Assigns new contents `k_val` / `v_val` to the storages `k_cache` / `v_cache`
+        at the location given by `batch_idx` and `input_pos`.
+
+        Args:
+            batch_idx (Tensor): batch index; shape :math:`(B)`.
+            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.
+            k_val, v_val (Tensor): values to be assigned; shape :math:`(B, H, S, D)`.
+            k_cache, v_cache (Tensor): caches to store the values; shape :math:`(1, H, MAX_S, D)`.
+        """
+        if k_val.requires_grad:
+            raise RuntimeError("val must not require gradient")
+
+        B, H, S, K_D = k_val.shape
+        V_D = v_val.shape[3]
+        if B != batch_idx.shape[0]:
+            raise RuntimeError(
+                f"Expect val and batch_idx have the same batch size "
+                f"but got B={B} and B={batch_idx.shape[0]}."
+            )
+        if H != k_cache.shape[1]:
+            raise RuntimeError(
+                f"Expect val and cache has the same number of heads "
+                f"but got H={H} and H={k_cache.shape[1]}."
+            )
+        if S != input_pos.shape[1]:
+            raise RuntimeError(
+                f"Expect val and input_pos has the same length "
+                f"but got S={S} and S={input_pos.shape[1]}."
+            )
+        if K_D != k_cache.shape[3]:
+            raise RuntimeError(
+                f"Expect k_val and k_cache has the same hidden dim "
+                f"but got D={K_D} and D={k_cache.shape[3]}."
+            )
+        if V_D != v_cache.shape[3]:
+            raise RuntimeError(
+                f"Expect v_val and v_cache has the same hidden dim "
+                f"but got D={V_D} and D={v_cache.shape[3]}."
+            )
+
+        # find address
+        logical_block_idx = input_pos // self.page_size  # [B, S]
+        logical_block_offset = input_pos % self.page_size  # [B, S]
+        physical_block_idx = torch.gather(
+            self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
+        ).to(torch.int32)  # [B, S]
+
+        addr = (physical_block_idx * self.page_size + logical_block_offset).view(
+            -1
+        )  # [B*S]
+
+        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
+        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)
+
+        k_cache[:, :, addr, :] = k_val
+        v_cache[:, :, addr, :] = v_val
+
+    def convert_logical_block_mask(
+        self,
+        block_mask: BlockMask,
+        batch_idx: torch.Tensor | None = None,
+        kv_len: torch.Tensor | None = None,
+    ) -> BlockMask:
+        """
+        Converts a logical block mask by mapping its logical kv indices to the corresponding
+        physical kv indices.
+
+        Args:
+            block_mask (BlockMask): logical block mask;
+                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
+            batch_idx (Tensor): batch index corresponding to the block_mask
+                batch dimension. This provides flexibility to convert a
+                block mask with smaller batch size than the page table;
+                shape :math:`(B)`.
+            kv_len (Optional[Tensor]): actual KV sequence length for upper bound check;
+                shape :math:`(B,)` to handle multiple batches.
+        """
+        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape
+
+        if block_mask.BLOCK_SIZE[1] != self.page_size:
+            raise RuntimeError(
+                f"Expect block_mask has the same column block size as page_size "
+                f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}"
+            )
+
+        # Increase the num columns of converted block mask from logical block mask's
+        # num columns to n_pages, since a) the converted block mask
+        # may have larger indices values; and b) `_ordered_to_dense` realizes
+        # a dense tensor with these converted indices. There would be an IndexError
+        # if using the logical block mask's num columns.
+
+        device = block_mask.kv_num_blocks.device
+
+        if batch_idx is None:
+            batch_idx = torch.arange(B, device=device)
+        page_table = self.page_table[batch_idx]
+
+        new_kv_num_blocks = block_mask.kv_num_blocks.clone()
+
+        new_kv_indices = torch.zeros(
+            (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
+        )
+        new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
+            torch.gather(
+                page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64)
+            )
+            .view(block_mask.kv_indices.shape)
+            .to(torch.int32)
+        )
+
+        new_full_kv_indices, new_full_kv_num_blocks = None, None
+        if block_mask.full_kv_num_blocks is not None:
+            assert block_mask.full_kv_indices is not None
+            new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone()
+            new_full_kv_indices = torch.zeros(
+                (B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
+            )
+            new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
+                torch.gather(
+                    page_table,
+                    1,
+                    block_mask.full_kv_indices.view(B, -1).to(torch.int64),
+                )
+                .view(block_mask.full_kv_indices.shape)
+                .to(torch.int32)
+            )
+
+        new_mask_mod = self.get_mask_mod(block_mask.mask_mod, kv_len)
+
+        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
+        return BlockMask.from_kv_blocks(
+            new_kv_num_blocks,
+            new_kv_indices,
+            new_full_kv_num_blocks,
+            new_full_kv_indices,
+            block_mask.BLOCK_SIZE,
+            new_mask_mod,
+            seq_lengths=seq_lengths,
+        )
+
+    def get_mask_mod(
+        self,
+        mask_mod: _mask_mod_signature | None,
+        kv_len: torch.Tensor | None = None,
+    ) -> _mask_mod_signature:
+        """
+        Converts a mask_mod based on mapping from the physical block index to the logical
+        block index.
+
+        Args:
+            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
+            kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check.
+        """
+        if mask_mod is None:
+            mask_mod = noop_mask
+
+        def new_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            physical_kv_idx: torch.Tensor,
+        ):
+            physical_kv_block = physical_kv_idx // self.page_size
+            physical_kv_offset = physical_kv_idx % self.page_size
+            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
+            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
+            live_block = logical_block_idx >= 0
+            within_upper_bound = (
+                logical_kv_idx < kv_len[b] if kv_len is not None else True
+            )
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
+
+            return torch.where(is_valid, mask_mod(b, h, q_idx, logical_kv_idx), False)
+
+        return new_mask_mod
+
+    def get_score_mod(
+        self,
+        score_mod: _score_mod_signature | None,
+        kv_len: torch.Tensor | None = None,
+    ) -> _score_mod_signature:
+        """
+        Converts a score_mod based on mapping from the physical block index to the logical
+        block index.
+
+        Args:
+            score_mod (_score_mod_signature): score_mod based on the logical block index.
+            kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check.
+
+        """
+        if score_mod is None:
+            score_mod = _identity
+
+        def new_score_mod(
+            score: torch.Tensor,
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            physical_kv_idx: torch.Tensor,
+        ):
+            physical_kv_block = physical_kv_idx // self.page_size
+            physical_kv_offset = physical_kv_idx % self.page_size
+            logical_block_idx = self.physical_to_logical[b, physical_kv_block]
+            logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
+            live_block = logical_block_idx >= 0
+            within_upper_bound = (
+                logical_kv_idx < kv_len[b] if kv_len is not None else True
+            )
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
+
+            return torch.where(
+                is_valid,
+                score_mod(score, b, h, q_idx, logical_kv_idx),
+                float("-inf"),
+            )
+
+        return new_score_mod
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a836cb4025b8157036d862fdda7a81568572833
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/_functions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1c33856976c135e78663933925140601a0dcf10
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/_functions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/activation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/activation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de0e5a29a7be4021c33e45b279af5f3904892145
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/activation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08b1985885016baf74ea3ba0f05376840937bd5d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8be33b41a2acd56dfd69c629d6dc4e4ad7d0763
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21bfdafd07139404dff3ae6bc17e62e3b9be62a2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/container.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/container.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc81d2f69174b44ce204f017a79eeae962c4d929
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/container.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/conv.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/conv.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9485c790dc7fb6bd0c313ff43b5f9509afc73712
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/conv.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/distance.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/distance.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..242d191e1987e35f91acaaed0131d60064c2d463
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/distance.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/dropout.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/dropout.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46afaa2de4cf271b64f69da4fbe2ec4cdcf02529
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/dropout.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/flatten.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/flatten.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07cea41d2ae7d34b4e0842613cc0b7ba979f9998
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/flatten.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/fold.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/fold.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7a8cdd4b29da41b5177e958a773060ca139264a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/fold.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..441139bc7fbf8c539da39491d7b2062cce957d67
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/lazy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/lazy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..036573573b772b6d02c249095e5e4acec3fc48f3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/lazy.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/linear.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/linear.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dc1edecccbb835bec127d6d31f1c25c3ff88d8f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/linear.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/normalization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/normalization.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04e2c0dcde59b5872cd0984b5bf2a13ab6902813
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/normalization.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/padding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/padding.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb795716eba24aebf23fcb04cce7a638836e7f7a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/padding.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34531ba6037395f11e2879a208a134df884e72b2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pooling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pooling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e773bed71bbac1901df43e25527b8bf08fd973f7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/pooling.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/rnn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/rnn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02897a19704daa02b17189a1f844babf004bc833
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/rnn.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/sparse.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/sparse.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5354800f42af9923d651fa62bd6313ef486d55f7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/sparse.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/transformer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/transformer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4893c2148b05ac0c4bdaa989009ae81ea48b7374
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/transformer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ff7c893a817df9cfca9a9f3714bf633c9dfb309
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e53d601f1553fefe800c53d0de195766f0f70f1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/modules/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eab4ffc0fb21037dce6a713742098ba1258b1c02
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f99cc94d06f5ff3ac396b89c9d9a4168abdad5d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c5f8fdd8c36cafbb572b39391a6b7538fcd9441
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68512286f946852436a6e6a70d725340a9350336
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4b950f0d3c1ed6e735273a8168528ff99e47bf9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5bd5077db0e94b2d06555b2d27d453a48aef415
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/qat/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/qat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..766b09382aa78e65aba915e4e6faf7979c500d1b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/qat/__init__.py
@@ -0,0 +1,19 @@
+# flake8: noqa: F401
+r"""QAT Modules.
+
+This package is in the process of being deprecated.
+Please, use `torch.ao.nn.qat` instead.
+"""
+
+from torch.nn.qat import dynamic, modules  # noqa: F403
+from torch.nn.qat.modules import *  # noqa: F403
+
+
+__all__ = [
+    "Linear",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "Embedding",
+    "EmbeddingBag",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/_reference/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/_reference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61faa90bd95cc7e255be2df82c617b5bab46b044
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/_reference/__init__.py
@@ -0,0 +1 @@
+from torch.nn.quantized._reference.modules import *  # noqa: F403
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b08cd1bc7149c5506db3a952fff488eb06749f5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/dynamic/__init__.py
@@ -0,0 +1 @@
+from torch.ao.nn.quantized.dynamic import *  # noqa: F403
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae76d1968b0faaf30f861ab009b9011ce2960cc5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/__init__.py
@@ -0,0 +1,97 @@
+r"""Quantized Modules.
+
+Note::
+    The `torch.nn.quantized` namespace is in the process of being deprecated.
+    Please, use `torch.ao.nn.quantized` instead.
+"""
+
+# The following imports are needed in case the user decides
+# to import the files directly,
+# e.g. `from torch.nn.quantized.modules.conv import ...`.
+# No need to add them to the `__all__`.
+from torch.ao.nn.quantized.modules import (
+    activation,
+    batchnorm,
+    conv,
+    DeQuantize,
+    dropout,
+    embedding_ops,
+    functional_modules,
+    linear,
+    MaxPool2d,
+    normalization,
+    Quantize,
+    rnn,
+    utils,
+)
+from torch.ao.nn.quantized.modules.activation import (
+    ELU,
+    Hardswish,
+    LeakyReLU,
+    MultiheadAttention,
+    PReLU,
+    ReLU6,
+    Sigmoid,
+    Softmax,
+)
+from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
+from torch.ao.nn.quantized.modules.conv import (
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+)
+from torch.ao.nn.quantized.modules.dropout import Dropout
+from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
+from torch.ao.nn.quantized.modules.functional_modules import (
+    FloatFunctional,
+    FXFloatFunctional,
+    QFunctional,
+)
+from torch.ao.nn.quantized.modules.linear import Linear
+from torch.ao.nn.quantized.modules.normalization import (
+    GroupNorm,
+    InstanceNorm1d,
+    InstanceNorm2d,
+    InstanceNorm3d,
+    LayerNorm,
+)
+from torch.ao.nn.quantized.modules.rnn import LSTM
+
+
+__all__ = [
+    "BatchNorm2d",
+    "BatchNorm3d",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
+    "DeQuantize",
+    "ELU",
+    "Embedding",
+    "EmbeddingBag",
+    "GroupNorm",
+    "Hardswish",
+    "InstanceNorm1d",
+    "InstanceNorm2d",
+    "InstanceNorm3d",
+    "LayerNorm",
+    "LeakyReLU",
+    "Linear",
+    "LSTM",
+    "MultiheadAttention",
+    "Quantize",
+    "ReLU6",
+    "Sigmoid",
+    "Softmax",
+    "Dropout",
+    "PReLU",
+    # Wrapper modules
+    "FloatFunctional",
+    "FXFloatFunctional",
+    "QFunctional",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/activation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85162ef35c7cd6a399f2a73a9a6b8f3c1154cd9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/activation.py
@@ -0,0 +1,20 @@
+# flake8: noqa: F401
+r"""Quantized Modules.
+
+This file is in the process of migration to `torch/ao/nn/quantized`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+appropriate file under the `torch/ao/nn/quantized/modules`,
+while adding an import statement here.
+"""
+
+from torch.ao.nn.quantized.modules.activation import (
+    ELU,
+    Hardswish,
+    LeakyReLU,
+    MultiheadAttention,
+    PReLU,
+    ReLU6,
+    Sigmoid,
+    Softmax,
+)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/embedding_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/embedding_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25f8bea7e378023a8eb3ece75a5fb9a23163529
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/embedding_ops.py
@@ -0,0 +1,18 @@
+# flake8: noqa: F401
+r"""Quantized Modules.
+
+This file is in the process of migration to `torch/ao/nn/quantized`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+appropriate file under the `torch/ao/nn/quantized/modules`,
+while adding an import statement here.
+"""
+
+from torch.ao.nn.quantized.modules.embedding_ops import (
+    Embedding,
+    EmbeddingBag,
+    EmbeddingPackedParams,
+)
+
+
+__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/linear.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9ba5a5c12f82915db53d81a7b9e5a1c0e530e98
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/linear.py
@@ -0,0 +1,14 @@
+# flake8: noqa: F401
+r"""Quantized Modules.
+
+This file is in the process of migration to `torch/ao/nn/quantized`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+appropriate file under the `torch/ao/nn/quantized/modules`,
+while adding an import statement here.
+"""
+
+from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams
+
+
+__all__ = ["LinearPackedParams", "Linear"]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea333af04ca49138a3b3ed35020654d4dad5ffe9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/quantized/modules/utils.py
@@ -0,0 +1,17 @@
+# flake8: noqa: F401
+r"""Quantized Modules.
+
+This file is in the process of migration to `torch/ao/nn/quantized`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+appropriate file under the `torch/ao/nn/quantized/modules`,
+while adding an import statement here.
+"""
+
+from torch.ao.nn.quantized.modules.utils import (
+    _hide_packed_params_repr,
+    _ntuple_from_first,
+    _pair_from_first,
+    _quantize_weight,
+    WeightedQuantizedModule,
+)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c149074e60b11f718ae35205969a5f78807af24
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d1153d9022a4b2f24b5b7686620b85d3ceec460
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4de98a7b474fabf1aa5f9e63bb981cf1e18d92a6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7673f48fd3f1d38840126f2c4a177c99faf85006
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..119d496bfe6dcbb9d0f1c414ee791cb10eccc131
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..502f2bf4b28bd40fbd15c92e76430f47b27b00cc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/fusion.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/fusion.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b40a68cfdf483a3cbc8bf99cc1044c5bccefd0a0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/fusion.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/init.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/init.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62b8f5e33da122b9d09adb8211aba4d7c45f6cec
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/init.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4777033d3276072e8c7fc9e56957cdd8b30a0d56
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f0950c3f816ffd0988ce243563606279e849c54
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fedd95978599d2b1a2a545e1d14139906d12ea28
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/prune.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/prune.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5488e20ab36b8258f5d981337a17c036639a67e1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/prune.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/rnn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/rnn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..182ec95d7fb826f616d57a110d1c26e8731e9c13
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/rnn.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9008d01767e4d757652ff3eb4820d31b23774ddb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/stateless.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/stateless.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5dfad81ba6040ebf2ede2fc79762ad4b28c669f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/stateless.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea5abbeb316cf482722fb7eb8899f56661e6808d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55bc55ddcb00c3fbaf85769230702488f7cf954d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60a536633c494b8dede5996733df0e9d6c4a7f7b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..445decb0f169f20d6b75232fef711857a1bae83e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08ba3b2d1c1e330f8ccaea52a05ec1a6fc98e336
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5d73b84cc4375ff611d36182fa27d4f6bf7bc2b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/_pass.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e34c27360d382b445a813faf2b2cbb96bd8ce441
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/_pass.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/type_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/type_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6aceaadf23f6c1e23895213d52f1b41fd45ec581
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/__pycache__/type_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff83563a5a0852f68db92c5059d6eb2f17faa72
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__init__.py
@@ -0,0 +1,6 @@
+from .type_promotion import InsertTypePromotion
+
+
+__all__ = [
+    "InsertTypePromotion",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6ca353466562e1bf6af7da9d0eedb9d08c81167
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f67b68c7fe39ac74e3a94198b4cfa252a5ba751a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d4e919a3b2fb409333ed182c183107a8f510931
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py
@@ -0,0 +1,1670 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["module: onnx"]
+from __future__ import annotations
+
+import abc
+import dataclasses
+import inspect
+import logging
+from typing import Any, TYPE_CHECKING
+
+import torch
+import torch._dispatch.python
+import torch._ops
+import torch.fx
+import torch.fx.traceback as fx_traceback
+from torch import _prims_common, _refs
+from torch._prims_common import (
+    ELEMENTWISE_TYPE_PROMOTION_KIND,
+    wrappers as _prims_common_wrappers,
+)
+from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
+from torch._refs.nn import functional as _functional_refs
+from torch.fx.experimental import proxy_tensor
+from torch.onnx._internal.fx import _pass, type_utils as fx_type_utils
+from torch.utils import _python_dispatch, _pytree
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Mapping, Sequence
+    from types import ModuleType
+
+    from torch._subclasses import fake_tensor
+
+
+logger = logging.getLogger(__name__)
+
+
+def _try_getclosurevars(func):
+    try:
+        return inspect.getclosurevars(func)
+    except TypeError:  # inspect.getclosurevars raises TypeError when func is not a Python function or method
+        return None  # report "no closure information" instead of propagating the error
+
+
+@dataclasses.dataclass
+class TypePromotionSnapshot:
+    """Type promotion snapshot for an fx node and its inputs.
+
+    Holds the promoted dtype for each arg and kwarg that needs promoting,
+    together with the expected output dtype of the node.
+    """
+
+    args_dtypes: Mapping[int, torch.dtype]
+    """Mapping from arg position to dtype to promote to."""
+
+    kwargs_dtypes: Mapping[str, torch.dtype]
+    """Mapping from kwarg name to dtype to promote to."""
+
+    out_dtype: torch.dtype
+    """Expected output dtype of the node."""
+
+
+class TypePromotionRule(abc.ABC):
+    """Abstract base class for a type promotion rule bound to 'torch.ops.{namespace}.{op_name}'."""
+
+    def __init__(self, namespace: str, op_name: str) -> None:
+        self.namespace = namespace  # e.g. 'aten' in 'torch.ops.aten.add'
+        self.op_name = op_name  # e.g. 'add' in 'torch.ops.aten.add'
+
+    # __hash__ is declared abstract as well because subclasses need to override __eq__().
+    # A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
+    # Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__
+    @abc.abstractmethod
+    def __hash__(self) -> int: ...
+
+    @abc.abstractmethod
+    def __repr__(self) -> str: ...
+
+    @abc.abstractmethod
+    def __eq__(self, other: object) -> bool: ...
+
+    def is_valid(self) -> bool:
+        """Return True if `torch.ops.{namespace}.{op_name}` exists and is an OpOverloadPacket."""
+        # This always returns a module. If the module does not exist it will be created.
+        module = getattr(torch.ops, self.namespace)
+        py_op = getattr(module, self.op_name, None)  # None default: a missing op must not raise here
+        if py_op is None:
+            logger.warning(
+                "Cannot find op: %s in module: %s", self.op_name, self.namespace
+            )
+            return False
+        if not isinstance(py_op, torch._ops.OpOverloadPacket):
+            logger.warning(
+                "Op: torch.ops.%s.%s is not an OpOverloadPacket, got: %s",
+                self.namespace,
+                self.op_name,
+                type(py_op),
+            )
+            return False
+
+        return True
+
+    @abc.abstractmethod
+    def preview_type_promotion(
+        self, args: tuple, kwargs: dict
+    ) -> TypePromotionSnapshot:
+        """Preview the type promotion results for the provided args and kwargs.
+
+        Returns a TypePromotionSnapshot object that contains the promoted dtypes for
+        the arguments and the expected output dtype.
+        """
+        ...
+
+
+class ElementwiseTypePromotionRule(TypePromotionRule):
+    """Defines how to perform elementwise type promotion for 'torch.ops.{namespace}.{op_name}'."""
+
+    _USE_OPMATH: bool = False
+    """Whether to use opmath to compute the promoted input dtype.
+    If used, upcasts will be inserted everywhere for lower precision models.
+    Set to False and have torchlib handle upcasts in op implementation internally.
+    """
+
+    def __init__(
+        self,
+        namespace: str,
+        op_name: str,
+        promote_args_positions: Sequence[int],
+        promote_kwargs_names: Sequence[str],
+        promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND,
+    ) -> None:
+        """Constructs a TypePromotionRule for elementwise operators.
+
+        Args:
+            namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.add'.
+            op_name: Name of the op. E.g. 'add' in 'torch.ops.aten.add'.
+            promote_args_positions: Positions of args to promote.
+            promote_kwargs_names: Names of kwargs to promote.
+            promotion_kind: Type promotion kind. Refer to [_prims_common.elementwise_dtypes](https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py) for detail.  # noqa: B950
+        """
+        super().__init__(namespace, op_name)
+        self.promote_args_positions = promote_args_positions
+        self.promote_kwargs_names = promote_kwargs_names
+        self.promotion_kind = promotion_kind
+
+    def __repr__(self) -> str:
+        return (
+            f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', "
+            f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
+        )
+
+    # pyrefly: ignore [bad-override]
+    def __eq__(self, other: object, /) -> bool:
+        if not isinstance(other, ElementwiseTypePromotionRule):
+            return False
+        return (
+            self.namespace == other.namespace
+            and self.op_name == other.op_name
+            and self.promote_args_positions == other.promote_args_positions
+            and self.promote_kwargs_names == other.promote_kwargs_names
+            and self.promotion_kind == other.promotion_kind
+        )
+
+    def __hash__(self) -> int:
+        return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()
+
+    def _consolidate_input_dtype(
+        self, computed_dtype: torch.dtype, result_dtype: torch.dtype
+    ) -> torch.dtype:
+        """
+        Although opmath is the right thing to do to retain on-par precision, it inserts
+        upcasts everywhere in the graph. This is particularly hard for backend to optimize
+        since there is no way to differentiate between inserted upcasts and model code
+        casts. Hence we consolidate the input dtype to the result dtype to avoid this.
+        """
+        if not self._USE_OPMATH and self.promotion_kind in (
+            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        ):
+            return result_dtype
+        return computed_dtype
+
+    def preview_type_promotion(
+        self, args: tuple, kwargs: dict
+    ) -> TypePromotionSnapshot:
+        candidate_args = {
+            i: args[i]
+            for i in self.promote_args_positions
+            if i < len(args) and args[i] is not None
+        }
+        candidate_kwargs = {
+            name: kwargs[name]
+            for name in self.promote_kwargs_names
+            if name in kwargs and kwargs[name] is not None
+        }
+
+        computed_dtype, result_dtype = _prims_common.elementwise_dtypes(
+            *_pytree.arg_tree_leaves(*candidate_args.values(), **candidate_kwargs),
+            type_promotion_kind=self.promotion_kind,
+        )
+
+        consolidated_input_dtype = self._consolidate_input_dtype(
+            computed_dtype, result_dtype
+        )
+
+        return TypePromotionSnapshot(
+            dict.fromkeys(candidate_args.keys(), consolidated_input_dtype),
+            dict.fromkeys(candidate_kwargs.keys(), consolidated_input_dtype),
+            result_dtype,
+        )
+
+
+class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
+    """Reference type promotion rule from torch._refs.div.
+
+    Rule depends on the value of the `rounding_mode` argument.
+    """
+
+    def __init__(self) -> None:
+        super().__init__(
+            "aten",
+            "div",
+            promote_args_positions=(0, 1),
+            promote_kwargs_names=(),
+            promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+        )
+
+    def preview_type_promotion(
+        self, args: tuple, kwargs: dict
+    ) -> TypePromotionSnapshot:
+        rounding_mode = kwargs.get("rounding_mode")
+        if rounding_mode is None:
+            # true_divide
+            self.promotion_kind = (
+                _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+            )
+            return super().preview_type_promotion(args, kwargs)
+        if rounding_mode == "trunc":
+            # trunc_divide
+            self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+            return super().preview_type_promotion(args, kwargs)
+        if rounding_mode == "floor":
+            # floor_divide
+            self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+            return super().preview_type_promotion(args, kwargs)
+        raise ValueError(f"Unknown rounding_mode: {rounding_mode}")
+
+
+class ReductionTypePromotionRule(TypePromotionRule):
+    def __init__(
+        self,
+        namespace: str,
+        op_name: str,
+        promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND,
+    ) -> None:
+        """Constructs a TypePromotionRule for reduction operators.
+
+        Args:
+            namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.sum'.
+            op_name: Name of the op. E.g. 'sum' in 'torch.ops.aten.sum'.
+            promotion_kind: Type promotion kind. Refer to [_prims_common.reduction_dtypes]((https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py)) for detail.  # noqa: B950
+        """
+        super().__init__(namespace, op_name)
+        self.promotion_kind = promotion_kind
+
+    def __repr__(self) -> str:
+        return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"
+
+    # pyrefly: ignore [bad-override]
+    def __eq__(self, other: object, /) -> bool:
+        if not isinstance(other, ElementwiseTypePromotionRule):
+            return False
+        return (
+            self.namespace == other.namespace
+            and self.op_name == other.op_name
+            and self.promotion_kind == other.promotion_kind
+        )
+
+    def __hash__(self) -> int:
+        return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()
+
+    def preview_type_promotion(
+        self, args: tuple, kwargs: dict
+    ) -> TypePromotionSnapshot:
+        assert len(args) >= 1, (
+            f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
+        )
+        arg = args[0]
+        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
+        dtype: torch.dtype | None = kwargs.get("dtype")
+
+        computation_dtype, result_dtype = _prims_common.reduction_dtypes(
+            arg, self.promotion_kind, dtype
+        )
+        if result_dtype is None:
+            # Inspecting code, this can only happen when `promotion_kind` is `KEEP_PROMOTED_TYPE`.
+            # Hence set same as computation_dtype.
+            result_dtype = computation_dtype
+
+        return TypePromotionSnapshot(
+            {0: computation_dtype},
+            {},
+            result_dtype,
+        )
+
+
+class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
+    """Reference type promotion rule from torch.ops.aten.all or torch.ops.aten.any.
+
+    This is a special case where computation dtype is always torch.bool.
+    The result dtype is always uint8 if `dtype` kwarg is uint8, otherwise torch.bool.
+    """
+
+    def __init__(self, op_name: str) -> None:
+        super().__init__(
+            "aten",
+            op_name,
+            _prims_common.REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL,
+        )
+
+    def preview_type_promotion(
+        self, args: tuple, kwargs: dict
+    ) -> TypePromotionSnapshot:
+        assert len(args) >= 1, (
+            f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
+        )
+        arg = args[0]
+        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
+        computation_dtype = torch.bool
+        # Preserves uint8 -- probably a legacy mask thing
+        result_dtype = torch.uint8 if arg.dtype == torch.uint8 else torch.bool
+        return TypePromotionSnapshot(
+            {0: computation_dtype},
+            {},
+            result_dtype,
+        )
+
+
+class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
+    """Reference type promotion rule from torch.ops.aten.sum.
+
+    This is a special case where computation dtype is always torch.int64 for integral arg,
+    unless overridden by `dtype` kwarg.
+    """
+
+    def preview_type_promotion(
+        self, args: tuple, kwargs: dict
+    ) -> TypePromotionSnapshot:
+        assert len(args) >= 1, (
+            f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
+        )
+        arg = args[0]
+        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
+        dtype: torch.dtype | None = kwargs.get("dtype")
+        # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
+        if dtype is None:
+            if _prims_common.is_boolean_dtype(
+                arg.dtype
+            ) or _prims_common.is_integer_dtype(arg.dtype):
+                dtype = torch.int64
+            else:
+                dtype = arg.dtype
+        return super().preview_type_promotion(args, {"dtype": dtype})
+
+
+# NOTE: [Update type promotion rule]
+# BELOW TABLE IS GENERATED FROM `TypePromotionRuleSetGenerator.generate_from_torch_refs`.
+# DO NOT EDIT MANUALLY !!!
+# For missing rules or discrepancies, please
+# 1. Run `pytest test/onnx/test_fx_type_promotion.py` to validate if the generated rule set is current.
+#    If it is not, update with new generated set.
+# 2. If discrepancies still exist, consider debugging torch._refs or report a bug.
+# 3. If rules are still missing, add them to `_EXTRA_TYPE_PROMOTION_RULE_SET` or report a bug.
+# Check `TypePromotionRule` class for how each rule is defined and used.
+_GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = {
+    ElementwiseTypePromotionRule(
+        "aten", "abs", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "abs_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "acos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "acos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "acosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "acosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "add", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "add_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "addcdiv", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "addcdiv_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "addcmul", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "addcmul_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "addr", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "asin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "asin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "asinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "asinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "atan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "atan2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "atan2_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "atan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "atanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "atanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "bitwise_left_shift",
+        [0, 1],
+        [],
+        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "bitwise_left_shift_",
+        [0, 1],
+        [],
+        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "bitwise_right_shift",
+        [0, 1],
+        [],
+        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "bitwise_right_shift_",
+        [0, 1],
+        [],
+        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "bitwise_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cat", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cauchy", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cauchy_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "ceil", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "ceil_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "celu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "celu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "clamp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "clamp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "copysign", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "copysign_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "cosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "deg2rad", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "deg2rad_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "digamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "dot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "elu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "eq", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "eq_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "erf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "erf_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "erfc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "erfc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "erfinv", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "erfinv_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "exp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "exp2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "exp2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "exp_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "expm1", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "expm1_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "exponential", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "exponential_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "fill", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "floor", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "floor_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "floor_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "floor_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "fmax", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "fmin", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "fmod", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "fmod_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "frac", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "frac_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "gcd", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "gcd_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "ge", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "ge_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "gelu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "geometric", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "geometric_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "glu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "gt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "gt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "hardtanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "heaviside", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "heaviside_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "huber_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "hypot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "hypot_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "i0", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "i0_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "igamma", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "igamma_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "igammac", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "igammac_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "isfinite", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "isinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "isnan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "isneginf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "isposinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "isreal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "l1_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lcm", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lcm_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "le", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "le_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "leaky_relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lerp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lerp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lgamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lgamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log10", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log10_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log1p", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log1p_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log_normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "log_normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logaddexp", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logaddexp2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logical_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "logsumexp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "lt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "maximum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "minimum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "mish", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "mish_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "mse_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "mul", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "mul_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "ne", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "ne_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "neg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "neg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "nextafter", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "nextafter_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "normal", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "poisson_nll_loss",
+        [0, 1],
+        [],
+        ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "rad2deg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "rad2deg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "reciprocal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "reciprocal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "remainder", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "remainder_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "round", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "rsqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "selu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sgn", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sgn_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sigmoid", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sigmoid_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sign", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sign_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "signbit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sinc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sinc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "smooth_l1_loss",
+        [0, 1],
+        [],
+        ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "softplus", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "square", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "square_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "sub_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "tan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "tan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "tanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "tanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "threshold", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "threshold_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "true_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "true_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "trunc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "vdot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "xlogy", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+    ElementwiseTypePromotionRule(
+        "aten", "xlogy_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    ),
+}
+
+# Manually curated extra type promotion rules. Please see NOTE [Update type promotion rule]
+# before adding new rules.
+# `TypePromotionTable.__init__` registers these after the generated rules, and rules
+# are keyed by "<namespace>.<op_name>", so an entry here replaces a generated entry
+# that has the same key.
+_EXTRA_TYPE_PROMOTION_RULE_SET = {
+    # torch._refs skips type promotion decoration for `clamp_min` and `clamp_max` since
+    # the call is routed to the decorated `aten.clamp` op.
+    ElementwiseTypePromotionRule(
+        "aten",
+        "clamp_max",
+        promote_args_positions=(0, 1),
+        promote_kwargs_names=(),
+        promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    ),
+    ElementwiseTypePromotionRule(
+        "aten",
+        "clamp_min",
+        promote_args_positions=(0, 1),
+        promote_kwargs_names=(),
+        promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    ),
+    # torch.ops.aten.div.Tensor_mode applies different type promotion rules
+    # depending on the value of the `mode` argument.
+    DivElementwiseTypePromotionRule(),
+    # Manually curating reduction ops since the logic is written inside the op reference
+    # implementation.
+    # NOTE(review): the string argument below is presumably the aten op name -- confirm
+    # against `AllOrAnyReductionTypePromotionRule`.
+    AllOrAnyReductionTypePromotionRule("all"),
+    AllOrAnyReductionTypePromotionRule("any"),
+    ReductionTypePromotionRule(
+        "aten",
+        "amax",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    ),
+    ReductionTypePromotionRule(
+        "aten",
+        "amin",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    ),
+    # torch.ops.aten.mean is a special case that does not need type promotion, which is
+    # why this set has no rule for it.
+    ReductionTypePromotionRule(
+        "aten",
+        "std",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
+    ),
+    ReductionTypePromotionRule(
+        "aten",
+        "std_mean",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
+    ),
+    ReductionTypePromotionRule(
+        "aten",
+        "var",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
+    ),
+    # Cumulative and sum-like reductions use their own rule class.
+    SumLikeReductionTypePromotionRule(
+        "aten",
+        "cumprod",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    ),
+    SumLikeReductionTypePromotionRule(
+        "aten",
+        "cumsum",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    ),
+    SumLikeReductionTypePromotionRule(
+        "aten",
+        "prod",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    ),
+    SumLikeReductionTypePromotionRule(
+        "aten",
+        "sum",
+        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    ),
+}
+
+
+class ElementwiseTypePromotionRuleSetGenerator:
+    """Distill elementwise type promotion rules from decorated reference ops.
+
+    This is admittedly a hack. The goal is to recover the decorator
+
+    ```python
+        @elementwise_type_promotion_wrapper(
+            type_promoting_args=("a", "b"),
+            type_promotion_kind=type_promotion_kind,
+        )
+    ```
+
+    from each reference op by walking its closure variables. The wrapper tells
+    which arguments are promoted and what kind of promotion is applied.
+    """
+
+    @classmethod
+    def generate_from_torch_refs(cls) -> set[ElementwiseTypePromotionRule]:
+        """Return the union of the rules parsed from each reference-op module."""
+        collected = set()
+        for ref_module in (
+            _refs,
+            _nn_refs,
+            _linalg_refs,
+            _special_refs,
+            _functional_refs,
+        ):
+            collected.update(cls._parse_torch_refs(ref_module))
+        return collected
+
+    @classmethod
+    def _parse_torch_refs(
+        cls, ref_module: ModuleType
+    ) -> set[ElementwiseTypePromotionRule]:
+        """Parse every name in `ref_module.__all__` and keep the valid rules."""
+        logger.info("Processing module: %s", ref_module.__name__)
+        parsed_rules = set()
+        for op_name in ref_module.__all__:
+            candidate = cls._parse_type_promotion_rule_from_refs_op(
+                getattr(ref_module, op_name)
+            )
+            # Skip ops without a discoverable rule and rules that fail validation.
+            if candidate is None or not candidate.is_valid():
+                continue
+            parsed_rules.add(candidate)
+
+        return parsed_rules
+
+    @classmethod
+    def _parse_type_promotion_rule_from_refs_op(
+        cls,
+        decorated_op: Callable,
+    ) -> ElementwiseTypePromotionRule | None:
+        """Retrieve and parse the type promotion decorator of a reference op.
+
+        Args:
+            decorated_op: The (possibly multiply decorated) reference op.
+
+        Returns:
+            A rule named after `decorated_op.__name__` in the "aten" namespace, or
+            None (after logging a warning) when no type promotion wrapper is found
+            in the chain of closures.
+        """
+        wrapper = None
+        current_fn = decorated_op
+        # Peel decorators one layer at a time: each layer is expected to keep the
+        # wrapped callable in a closure variable named "fn".
+        while True:
+            closure_vars = _try_getclosurevars(current_fn)
+            if not closure_vars:
+                break
+            nonlocals = closure_vars.nonlocals
+            if "fn" not in nonlocals:
+                break
+            if "self" in nonlocals and isinstance(
+                nonlocals["self"],
+                _prims_common_wrappers.elementwise_type_promotion_wrapper,
+            ):
+                wrapper = nonlocals["self"]
+                break
+            current_fn = nonlocals["fn"]
+
+        if wrapper is None:
+            logger.warning(
+                "Cannot find type promotion rule for: %s.%s",
+                decorated_op.__module__,
+                decorated_op.__name__,
+            )
+            return None
+
+        signature = inspect.signature(decorated_op)
+        positional_hits = []
+        keyword_hits = []
+
+        promoting_names = wrapper.type_promoting_arg_names
+        if promoting_names is not None:
+            # `index` counts every parameter of the signature, promoted or not.
+            for index, (param_name, param) in enumerate(signature.parameters.items()):
+                if param_name not in promoting_names:
+                    continue
+                if param.kind in (
+                    param.POSITIONAL_OR_KEYWORD,
+                    param.POSITIONAL_ONLY,
+                ):
+                    positional_hits.append(index)
+                elif param.kind == param.KEYWORD_ONLY:
+                    keyword_hits.append(param_name)
+
+        return ElementwiseTypePromotionRule(
+            "aten",
+            decorated_op.__name__,
+            promote_args_positions=positional_hits,
+            promote_kwargs_names=keyword_hits,
+            promotion_kind=wrapper.type_promotion_kind,
+        )
+
+
+class TypePromotionTable:
+    """Lookup table mapping "<namespace>.<op_name>" to a type promotion rule."""
+
+    def __init__(self) -> None:
+        self._rule_table = {}
+        # Generated rules go in first; curated extras follow and win on key clashes.
+        for rule_set in (
+            _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET,
+            _EXTRA_TYPE_PROMOTION_RULE_SET,
+        ):
+            for rule in rule_set:
+                self.add_rule(rule)
+
+    def add_rule(self, rule: TypePromotionRule) -> None:
+        """Register a type promotion rule, replacing any rule with the same key.
+
+        Args:
+            rule: Type promotion rule. It is stored under
+                "<rule.namespace>.<rule.op_name>".
+
+        Raises:
+            ValueError: If the rule is invalid.
+        """
+        if not rule.is_valid():
+            raise ValueError(f"Invalid type promotion rule: {rule}")
+        key = f"{rule.namespace}.{rule.op_name}"
+        self._rule_table[key] = rule
+
+    def get_rule(self, py_op: torch._ops.OpOverloadPacket) -> TypePromotionRule | None:
+        """Return the rule registered under `str(py_op)`, or None if there is none."""
+        return self._rule_table.get(str(py_op))
+
+
+def get_type_promotion_rule(
+    node: torch.fx.Node,
+    type_promotion_table: TypePromotionTable,
+) -> TypePromotionRule | None:
+    """Look up the type promotion rule that applies to a node.
+
+    Args:
+        node: Node whose `target` is inspected.
+        type_promotion_table: Table to query.
+
+    Returns:
+        The rule for the node's op, or None when the target is not a
+        `torch._ops.OpOverload` or the table has no rule for it.
+    """
+    target = node.target
+    if isinstance(target, torch._ops.OpOverload):
+        # The table is queried with the overload packet, not the single overload.
+        return type_promotion_table.get_rule(target.overloadpacket)
+    return None
+
+
+class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
+    """Trace ops that were dispatched.
+
+    Utilize the dispatch mechanism in [`__torch_dispatch__`](https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)
+    to trace op overloads that were dispatched to. This is used to find the compatible
+    op overload for a given op overload packet for different set of args and kwargs.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        # Every callable passed to `__torch_dispatch__`, in call order.
+        self.traced_ops = []
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        """Record `func`, then call it with `args`/`kwargs` and return its result."""
+        self.traced_ops.append(func)
+        # `kwargs` defaults to None and `**None` raises TypeError, so treat None as
+        # "no keyword arguments".
+        return func(*args, **({} if kwargs is None else kwargs))
+
+
+def find_compatible_op_overload(
+    op: torch._ops.OpOverloadPacket, args: tuple, kwargs: dict
+) -> torch._ops.OpOverload:
+    """Find the OpOverload of an OpOverloadPacket that matches the given args and kwargs.
+
+    Each "call_function" fx.Node in the fx.GraphModule has a target that represents a
+    torch._ops.OpOverload, and the OpOverload refers back to the OpOverloadPacket that
+    holds all the available overloads for the operation.
+
+    The type promotion pass may change the types of the args and kwargs, for example by
+    promoting Python numbers to tensors, so the original OpOverload may no longer fit.
+    This function calls the packet once under a tracing dispatch mode and returns the
+    first op that was dispatched.
+
+    Args:
+        op: OpOverloadPacket to find compatible OpOverload for.
+        args: The positional arguments to consider for compatibility.
+        kwargs: The keyword arguments to consider for compatibility.
+
+    Returns:
+        torch._ops.OpOverload: The compatible OpOverload found for the given args and kwargs.
+
+    Raises:
+        AssertionError: If nothing was traced, if the first traced op is not an
+            OpOverload, or if it belongs to a different packet than `op`.
+            Exceptions raised by calling `op` itself propagate unchanged.
+
+    Examples:
+        >>> import torch
+        >>> packet = torch.ops.aten.pow
+        >>> args = (torch.tensor([1.0, 2.0]), 2)
+        >>> find_compatible_op_overload(packet, args, {})._overloadname
+        'Tensor_Scalar'
+        >>> args = (torch.tensor([1.0, 2.0]), torch.tensor(2.0))
+        >>> find_compatible_op_overload(packet, args, {})._overloadname
+        'Tensor_Tensor'
+    """
+    # Let the dispatcher do the overload resolution and just observe what it picked.
+    tracer = _OpTraceDispatchMode()
+    with tracer:
+        op(*args, **kwargs)
+
+    dispatched = tracer.traced_ops
+    assert len(dispatched) >= 1, (
+        "Expected at least 1 traced op, got 0"
+    )
+
+    # Only the first dispatched op is considered.
+    selected = dispatched[0]
+    assert isinstance(selected, torch._ops.OpOverload), (
+        f"Expected OpOverload, got {type(selected)}"
+    )
+    assert selected.overloadpacket == op, (
+        f"Expected same OpOverload packet, got {selected.overloadpacket} != {op}"
+    )
+
+    return selected
+
+
+class _TypePromotionInterpreter(torch.fx.Interpreter):
+    """Interpreter that inserts type promotion for each node.
+
+    While interpreting the graph, nodes whose target has a rule in the type promotion
+    table get explicit cast nodes inserted for their arguments (and, when needed, for
+    their output). The graph of the wrapped module is mutated in place.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        type_promotion_table: TypePromotionTable,
+    ) -> None:
+        super().__init__(module)
+        self.type_promotion_table = type_promotion_table
+
+    def _run_node_and_set_meta(self, node) -> Any:
+        """Run node and set meta according to `fx_traceback.get_current_meta()`.
+
+        This should be used on new nodes or nodes that have been modified.
+        By default `Interpreter.run_node` does not update `node.meta`.
+        Keys of the current meta that are missing from `node.meta` are copied over;
+        keys already present are left untouched, except for `node.meta["val"]`, which
+        is always recomputed from the node's output.
+
+        Returns:
+            The output of running the node.
+        """
+        out = super().run_node(node)
+        # Update interpreter env state with new output value.
+        self.env[node] = out
+        # Only fill in meta keys the node does not have yet; never overwrite.
+        node.meta.update(
+            (k, v)
+            for k, v in fx_traceback.get_current_meta().items()
+            if k not in node.meta
+        )
+        node.meta["val"] = proxy_tensor.extract_val(out)
+        return out
+
+    def _create_node(
+        self,
+        graph: torch.fx.Graph,
+        op_type: str,
+        target: torch.fx.node.Target,
+        args: tuple,
+        kwargs: dict,
+    ) -> torch.fx.Node:
+        """Create a node and set its metadata.
+
+        The node is created through the `graph` method named `op_type` (at the graph's
+        current insertion point), then run immediately so that `self.env` and
+        `node.meta` are populated.
+        """
+        assert op_type in (
+            "call_function",
+            "call_method",
+            "get_attr",
+            "call_module",
+            "placeholder",
+            "output",
+        ), f"Unexpected op_type: {op_type}"
+        # Dispatch to e.g. `graph.call_function(target, args, kwargs)`.
+        node = getattr(graph, op_type)(target, args, kwargs)
+        self._run_node_and_set_meta(node)
+        return node
+
+    def _rerun_node_after_type_promotion(
+        self,
+        node: torch.fx.Node,
+        expected_out_dtype: torch.dtype,
+    ) -> None:
+        """Rerun a node after type promotion and update node.meta["val"] with the output value.
+
+        The node's target is first replaced by the overload that matches its promoted
+        arguments. If the rerun produces a tensor whose dtype differs from
+        `expected_out_dtype`, a cast node is inserted after `node` and all users of
+        `node` are redirected to it.
+
+        Raises:
+            NotImplementedError: If the previous output value is a symbolic type, a list
+                or a tuple.
+            RuntimeError: If the previous output value is of any other non-tensor type.
+        """
+        node_val = node.meta.get("val", None)
+        assert node_val is not None, f"Node {node} node.meta['val'] is not set."
+        args, kwargs = self.fetch_args_kwargs_from_env(node)
+        target = node.target
+        assert isinstance(target, torch._ops.OpOverload), (
+            f"Expected OpOverload, got {type(target)}"
+        )
+        # The promoted args may no longer match the original overload; pick the one
+        # that does.
+        node.target = find_compatible_op_overload(target.overloadpacket, args, kwargs)
+
+        new_node_val = self._run_node_and_set_meta(node)
+        assert isinstance(new_node_val, type(node_val)), (
+            f"run_node output type should not change between runs. "
+            f"Got {type(new_node_val)}, expect {type(node_val)}."
+        )
+
+        if isinstance(node_val, torch.Tensor):
+            # `node_val` is the value recorded before promotion.
+            prev_node_dtype = node_val.dtype
+
+            assert prev_node_dtype == expected_out_dtype, (
+                f"node.meta['val'].dtype({prev_node_dtype}) does not agree with "
+                f"type promotion rule({expected_out_dtype})."
+            )
+
+            if new_node_val.dtype != expected_out_dtype:
+                # With explicit type promotion, the expected result dtype may not be
+                # the same as the computation dtype. This is referred to as "op math".
+                # We need to explicitly cast the output back to the expected dtype.
+                # See more about "op math" topic at `_prims_common.elementwise_dtypes`.
+                graph = node.graph
+                with graph.inserting_after(node):
+                    output_cast_node = self._create_node(
+                        graph,
+                        "call_function",
+                        torch.ops.prims.convert_element_type.default,
+                        (node,),
+                        {"dtype": expected_out_dtype},
+                    )
+                    # Redirecting all users of `node` also rewrites the cast node's own
+                    # input, so point the cast node back at `node` right afterwards.
+                    node.replace_all_uses_with(output_cast_node)
+                    output_cast_node.args = (node,)
+                    logger.info(
+                        "Node '%s' output dtype becomes %s due to op math. "
+                        "Cast back to %s.",
+                        node,
+                        new_node_val.dtype,
+                        expected_out_dtype,
+                    )
+
+        elif fx_type_utils.is_torch_symbolic_type(node_val):
+            raise NotImplementedError(
+                "Type promotion does not support node output of sym types."
+            )
+        elif isinstance(node_val, (list, tuple)):
+            raise NotImplementedError(
+                "Type promotion does not support node output of list or tuple."
+            )
+        else:
+            raise RuntimeError(f"Unexpected node output type: {type(node_val)}.")
+
+    def _maybe_promote_arg(
+        self,
+        node: torch.fx.Node,
+        fx_arg: torch.fx.node.Argument,
+        dtype: torch.dtype | None,
+    ) -> torch.fx.node.Argument:
+        """Promote fx_arg to dtype if necessary.
+
+        Args:
+            node: Node that consumes `fx_arg`; new nodes are inserted before it.
+            fx_arg: The argument as stored on the node (a node, a Python scalar, or a
+                tuple/list of those).
+            dtype: Target dtype, or None if the rule does not mention this argument.
+
+        Returns:
+            `fx_arg` itself when no promotion is needed, otherwise a newly created node
+            that converts it to `dtype` (tuples/lists are rebuilt element by element).
+
+        Raises:
+            NotImplementedError: If `fx_arg` is of a kind this method does not handle.
+        """
+        if dtype is None:
+            logger.info(
+                "Argument %s is not promoted. Not mentioned by type promotion rule.",
+                fx_arg,
+            )
+            return fx_arg
+
+        if isinstance(fx_arg, torch.fx.Node):
+            # Decide based on the value the interpreter computed for this node.
+            arg_val = self.env[fx_arg]
+            if isinstance(arg_val, torch.Tensor):
+                if (old_dtype := arg_val.dtype) != dtype:
+                    # Promote tensor to dtype.
+                    graph = node.graph
+                    with graph.inserting_before(node):
+                        logger.info(
+                            "Argument %s(%s) is promoted to %s.",
+                            fx_arg,
+                            old_dtype,
+                            dtype,
+                        )
+                        return self._create_node(
+                            graph,
+                            "call_function",
+                            torch.ops.prims.convert_element_type.default,
+                            (fx_arg,),
+                            {"dtype": dtype},
+                        )
+                logger.info("Argument %s is not promoted. Already %s.", fx_arg, dtype)
+                return fx_arg
+            elif fx_type_utils.is_torch_symbolic_type(arg_val):
+                arg_type = type(arg_val)
+                equivalent_dtype = fx_type_utils.from_scalar_type_to_torch_dtype(
+                    arg_type
+                )
+                assert equivalent_dtype is not None, f"Unexpected arg_type: {arg_type}"
+                if equivalent_dtype != dtype:
+                    # Promote Sym number to tensor of dtype.
+                    graph = node.graph
+                    with graph.inserting_before(node):
+                        logger.info(
+                            "Argument %s(Scalar of equivalent dtype: %s) "
+                            "is promoted to %s.",
+                            fx_arg,
+                            equivalent_dtype,
+                            dtype,
+                        )
+                        return self._create_node(
+                            graph,
+                            "call_function",
+                            torch.ops.aten.scalar_tensor.default,
+                            (fx_arg,),
+                            {"dtype": dtype},
+                        )
+                logger.info("Argument %s is not promoted. Already %s.", fx_arg, dtype)
+                return fx_arg
+            # A node whose value is neither a tensor nor a symbolic type falls through
+            # to the NotImplementedError at the end of this method.
+        elif (
+            equivalent_dtype := fx_type_utils.from_scalar_type_to_torch_dtype(
+                type(fx_arg)
+            )
+        ) is not None:
+            if equivalent_dtype != dtype:
+                # Promote number to tensor of dtype.
+                # The op should have overload that supports tensor for this arg, otherwise
+                # the type promotion rule should not suggest promoting this arg.
+                graph = node.graph
+                with graph.inserting_before(node):
+                    logger.info(
+                        "Argument %s(Scalar of equivalent dtype: %s) "
+                        "is promoted to %s.",
+                        fx_arg,
+                        equivalent_dtype,
+                        dtype,
+                    )
+                    return self._create_node(
+                        graph,
+                        "call_function",
+                        torch.ops.aten.scalar_tensor.default,
+                        (fx_arg,),
+                        {"dtype": dtype},
+                    )
+            logger.info("Argument %s is not promoted. Already %s.", fx_arg, dtype)
+            return fx_arg
+        elif isinstance(fx_arg, (tuple, list)):
+            logger.info("Argument %s is a tuple/list. Promoting each element.", fx_arg)
+            # Rebuild a container of the same type with each element promoted.
+            return type(fx_arg)(
+                self._maybe_promote_arg(node, fx_arg_elem, dtype)
+                for fx_arg_elem in fx_arg
+            )
+
+        raise NotImplementedError(f"Unknown fx arg type: {type(fx_arg)}")
+
+    def _maybe_promote_node(
+        self,
+        node: torch.fx.Node,
+        rule: TypePromotionRule,
+    ) -> torch.fx.Node:
+        """Promote node inputs and outputs according to type promotion rule.
+
+        The rule previews the promotion from the values in the interpreter env, and the
+        promotion itself is applied to the fx arguments stored on the node. The node is
+        rerun only if any argument actually changed.
+
+        Returns:
+            The same `node`, possibly with updated args, kwargs and target.
+        """
+        args, kwargs = self.fetch_args_kwargs_from_env(node)
+        type_promotion_info = rule.preview_type_promotion(args, kwargs)
+        new_args = []
+        new_kwargs = {}
+        for i, arg in enumerate(node.args):
+            new_args.append(
+                self._maybe_promote_arg(
+                    node, arg, type_promotion_info.args_dtypes.get(i, None)
+                )
+            )
+
+        for name, arg in node.kwargs.items():
+            new_kwargs[name] = self._maybe_promote_arg(
+                node, arg, type_promotion_info.kwargs_dtypes.get(name, None)
+            )
+        # `node.args` is compared against a tuple below.
+        new_args = tuple(new_args)
+
+        if node.args != new_args or node.kwargs != new_kwargs:
+            node.args = new_args
+            node.kwargs = new_kwargs
+            self._rerun_node_after_type_promotion(node, type_promotion_info.out_dtype)
+
+        return node
+
+    def run_node(self, n: torch.fx.Node) -> Any:
+        """This method is an override which inserts type promotion nodes as needed.
+
+        For each `call_function` node, an initial check is conducted to determine if a type
+        promotion rule is applicable. If a relevant rule exists, type casting nodes are
+        introduced for the corresponding arguments. The OpOverload of the node is updated
+        to one that accommodates the promoted types. Should the output type be different,
+        type casting node is inserted for this output.
+
+        The call `super().run_node(node)` is guaranteed to be invoked for each node.
+        In the case of new or modified nodes, the result of `super().run_node(node)` is
+        used to update its `node.meta["val"]` value.
+        """
+        with self._set_current_node(n):
+            if rule := get_type_promotion_rule(n, self.type_promotion_table):
+                # The returned node is `n` itself, so the result is not needed here.
+                self._maybe_promote_node(n, rule)
+
+        # Always run `n` (again, if it was modified above) to produce the value returned
+        # to the interpreter loop.
+        return super().run_node(n)
+
+
+class InsertTypePromotion(_pass.Transform):
+    """Explicitly insert type promotion ops to the graph.
+
+    Underneath, the main pass is driven by `_TypePromotionInterpreter`, which is a subclass
+    of `torch.fx.Interpreter` to interpret the fx.Graph and perform the insertion of type
+    promotion operations.
+
+    By re-running the new and modified nodes using the interpreter, we can update the
+    metadata, specifically the fake tensor stored under node.meta["val"], and ensure it
+    reflects the latest changes.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        type_promotion_table: TypePromotionTable | None = None,
+    ) -> None:
+        """Set up the pass.
+
+        Args:
+            module: Graph module to transform; its graph is mutated in place.
+            type_promotion_table: Rules to apply. A falsy value (including the default
+                None) makes the pass build a fresh `TypePromotionTable()`.
+        """
+        super().__init__(module)
+        self.interpreter = _TypePromotionInterpreter(
+            module, type_promotion_table or TypePromotionTable()
+        )
+
+    def _fetch_fake_args(
+        self,
+    ) -> Sequence[
+        fake_tensor.FakeTensor
+        | float
+        | int
+        | bool
+        | torch.SymInt
+        | torch.SymFloat
+        | torch.SymBool
+        | None
+    ]:
+        """Fetch fake args from fx graph.
+
+        For each placeholder node, in graph order, take `node.meta["val"]`, or None
+        when the node has no "val" entry.
+
+        Returns:
+            One entry per placeholder node.
+        """
+        fake_args = []
+        for node in self.module.graph.nodes:
+            if node.op == "placeholder":
+                try:
+                    # Meta value can be torch.Tensor, int, float, bool,
+                    # torch.SymInt, torch.SymFloat, torch.SymBool.
+                    meta_value = node.meta.get("val", None)
+                except RuntimeError as e:
+                    # NOTE(review): `.get` on a plain dict does not raise RuntimeError,
+                    # so this handler looks unreachable; a used placeholder without
+                    # "val" yields None instead of this error -- confirm the intent.
+                    if not node.users:
+                        # If the placeholder is not used, we can safely ignore it and put
+                        # None as placeholder.
+                        meta_value = None
+                    else:
+                        raise RuntimeError(
+                            "Cannot fetch symbolic fake args from fx graph. "
+                            "InsertTypePromotion pass needs to run with pre-existing fake args, "
+                            "Otherwise the pass will produce inaccurate dynamic shape. "
+                        ) from e
+
+                fake_args.append(meta_value)
+        return fake_args
+
+    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
+        """Run the type promotion interpreter over the module and return the module.
+
+        No arguments are accepted: the inputs are the fake values read from the graph's
+        placeholder nodes.
+        """
+        assert not args, (
+            "`InsertTypePromotion` deduces symbolic fake arguments from the graph. "
+            "It does not accept concrete arguments as input because this pass requires "
+            "re-running the graph. When executed with newly faked concrete arguments, "
+            "the pass loses the symbolic dynamic shape information."
+        )
+        assert not kwargs, "`kwargs` is not supported"
+
+        fake_args = self._fetch_fake_args()
+        # NOTE(review): `fake_mode` is not defined in this class; presumably provided
+        # by `_pass.Transform` -- confirm.
+        fake_mode = self.fake_mode
+        assert fake_mode is not None, "Cannot detect fake_mode."
+
+        # Use the python dispatcher to run through some python kernels which
+        # can better handle symints. Without this, some SymInts can become static
+        # when there are dynamic shapes.
+        dispatcher_mode = torch._dispatch.python.enable_python_dispatcher()
+        with fake_mode, dispatcher_mode, fx_traceback.preserve_node_meta():
+            self.interpreter.run(*fake_args)
+
+        return self.module
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dd2b53385176ec36b169bb8c5e938772d46c710
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_dtype_mappings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_dtype_mappings.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8aa9deb6da9e491e6d316eadc6645d538190921b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_dtype_mappings.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_impl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89d5280ea3bb1ec0030c7a9b19fef0801cce2734
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_impl.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_symbolic_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_symbolic_impl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07071baa4e03d66d992a15e5ac789cba07391e13
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/onnx/ops/__pycache__/_symbolic_impl.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddcd4a54a079da41b2d7cef93e60bde5f822a585
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77c9518321559f4c00f87dcdf474d51ddb46a870
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a182fe84f6c032af6d3ce50d2fdda43663f39305
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ca413d53cef63080796a1194e560c7703ad5acf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6935d9d24ad1f8eefcf0361cd423ca2555b31dbd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53d8752b2a202387f50f50b37f171eca5edf4689
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1df629784dcce4e35fca1b72df55ca947a04c9a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eaf7e399f0149216c5adb8b7eb3e4e0704a187f5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..948cc6d2fce917d6c2ff5eef064dcc0d9947f067
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e306526c3c25aea551556473c6102a401049e6d0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ea91eee1ad0d42320defe89e9be9b001cb1168a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85df37525b7e6e527247eb79273856ddd91bb9a4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69775369d21ad048d42a9b5f2f21145230f20015
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/ATen/ATenConfig.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/ATen/ATenConfig.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..0ce7803dbf78897298d81c2679f2cdb3c872bc15
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/ATen/ATenConfig.cmake
@@ -0,0 +1,9 @@
+# Find the ATen includes and library
+#
+# ATEN_INCLUDE_DIR -- where to find the includes
+# ATEN_LIBRARIES -- list of libraries to link against
+# ATEN_FOUND -- set to 1 if found
+
+set(ATEN_FOUND 1)
+set(ATEN_INCLUDE_DIR "/pytorch/torch/include")
+set(ATEN_LIBRARIES "")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..2457dff032a8b824d173fe1cb2d4e787a7b9839c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake
@@ -0,0 +1,140 @@
+# - Config file for the Caffe2 package
+# It defines the following variable(s)
+#   CAFFE2_INCLUDE_DIRS     - include directories for Caffe2
+# as well as Caffe2 targets for other cmake libraries to use.
+
+# library version information
+
+# Utils functions.
+include("${CMAKE_CURRENT_LIST_DIR}/public/utils.cmake")
+
+# Depending on whether Caffe2 uses gflags during compile time or
+# not, invoke gflags.
+if(OFF)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/gflags.cmake")
+  if(NOT TARGET gflags)
+    message(FATAL_ERROR
+        "Your installed Caffe2 version uses gflags but the gflags library "
+        "cannot be found. Did you accidentally remove it, or have you set "
+        "the right CMAKE_PREFIX_PATH and/or GFLAGS_ROOT_DIR? If you do not "
+        "have gflags, you will need to install gflags and set the library "
+        "path accordingly.")
+  endif()
+endif()
+
+# Depending on whether Caffe2 uses glog during compile time or
+# not, invoke glog.
+if(OFF)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/glog.cmake")
+  if(NOT TARGET glog::glog)
+    message(FATAL_ERROR
+        "Your installed Caffe2 version uses glog but the glog library "
+        "cannot be found. Did you accidentally remove it, or have you set "
+        "the right CMAKE_PREFIX_PATH and/or GFLAGS_ROOT_DIR? If you do not "
+        "have glog, you will need to install glog and set the library "
+        "path accordingly.")
+  endif()
+endif()
+
+# Protobuf
+if(ON)
+  if(NOT TARGET protobuf::libprotobuf)
+    # Define protobuf::libprotobuf as a dummy target to resolve references to
+    # protobuf::libprotobuf in Caffe2Targets.cmake.
+    add_library(dummy INTERFACE)
+    add_library(protobuf::libprotobuf ALIAS dummy)
+  endif()
+else()
+  include("${CMAKE_CURRENT_LIST_DIR}/public/protobuf.cmake")
+  if(NOT TARGET protobuf::libprotobuf)
+    message(FATAL_ERROR
+        "Your installed Caffe2 version uses protobuf but the protobuf library "
+        "cannot be found. Did you accidentally remove it, or have you set "
+        "the right CMAKE_PREFIX_PATH? If you do not have protobuf, you will "
+        "need to install protobuf and set the library path accordingly.")
+  endif()
+  message(STATUS "Caffe2: Protobuf version " ${Protobuf_VERSION})
+  # If during build time we know the protobuf version, we will also do a sanity
+  # check to ensure that the protobuf library that Caffe2 found is consistent
+  # with the compiled version.
+  if(FALSE)
+    if(NOT (${Protobuf_VERSION} VERSION_EQUAL Protobuf_VERSION_NOTFOUND))
+      message(FATAL_ERROR
+          "Your installed Caffe2 is built with protobuf "
+          "Protobuf_VERSION_NOTFOUND"
+          ", while your current cmake setting discovers protobuf version "
+          ${Protobuf_VERSION}
+          ". Please specify a protobuf version that is the same as the built "
+          "version.")
+    endif()
+  endif()
+endif()
+
+if (OFF)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake")
+endif()
+
+if(ON)
+  # The file public/cuda.cmake exclusively uses CAFFE2_USE_*.
+  # If Caffe2 was compiled with the libraries below, they must
+  # be found again when including the Caffe2 target.
+  set(CAFFE2_USE_CUDA ON)
+
+  # Add current directory to module path so we pick up FindCUDAToolkit.cmake
+  set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}")
+  list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}")
+  include("${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake")
+  set(CMAKE_MODULE_PATH "${old_CMAKE_MODULE_PATH}")
+
+  if(ON AND NOT CAFFE2_USE_CUDA)
+    message(FATAL_ERROR
+      "Your installed Caffe2 version uses CUDA but I cannot find the CUDA "
+      "libraries. Please set the proper CUDA prefixes and / or install "
+      "CUDA.")
+  endif()
+endif()
+
+if(OFF)
+  # Add current directory to module path so we pick up FindSYCLToolkit.cmake
+  set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}")
+  list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}")
+  include("${CMAKE_CURRENT_LIST_DIR}/public/xpu.cmake")
+  set(CMAKE_MODULE_PATH "${old_CMAKE_MODULE_PATH}")
+
+  if(OFF AND NOT PYTORCH_FOUND_XPU)
+    message(FATAL_ERROR
+      "Your installed Caffe2 version uses XPU but I cannot find the XPU runtime"
+      "libraries. Please set the proper oneAPI paths and / or install "
+      "oneAPI.")
+  endif()
+endif()
+
+if(ON)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/mkl.cmake")
+endif()
+
+if(ON)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake")
+endif()
+
+# import targets
+include ("${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets.cmake")
+
+# Interface libraries, that allows one to build proper link flags.
+# We will also define a helper variable, Caffe2_MAIN_LIBS, that resolves to
+# the main caffe2 libraries in cases of cuda presence / absence.
+set(Caffe2_MAIN_LIBS torch_library)
+
+# include directory.
+#
+# Newer versions of CMake set the INTERFACE_INCLUDE_DIRECTORIES property
+# of the imported targets. It is hence not necessary to add this path
+# manually to the include search path for targets which link to Caffe2.
+# The following lines are here for backward compatibility, in case one
+# would like to use the old-style include path.
+get_filename_component(
+    CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+# Note: the current list dir is _INSTALL_PREFIX/share/cmake/Caffe2.
+get_filename_component(
+    _INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
+set(CAFFE2_INCLUDE_DIRS "${_INSTALL_PREFIX}/include")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..721afaa1b956f721ecd584a69ae59de56f5e5064
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake
@@ -0,0 +1,71 @@
+#----------------------------------------------------------------
+# Generated CMake target import file for configuration "Release".
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Import target "c10_cuda" for configuration "Release"
+set_property(TARGET c10_cuda APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(c10_cuda PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libc10_cuda.so"
+  IMPORTED_SONAME_RELEASE "libc10_cuda.so"
+  )
+
+list(APPEND _cmake_import_check_targets c10_cuda )
+list(APPEND _cmake_import_check_files_for_c10_cuda "${_IMPORT_PREFIX}/lib/libc10_cuda.so" )
+
+# Import target "c10" for configuration "Release"
+set_property(TARGET c10 APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(c10 PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libc10.so"
+  IMPORTED_SONAME_RELEASE "libc10.so"
+  )
+
+list(APPEND _cmake_import_check_targets c10 )
+list(APPEND _cmake_import_check_files_for_c10 "${_IMPORT_PREFIX}/lib/libc10.so" )
+
+# Import target "torch_nvshmem" for configuration "Release"
+set_property(TARGET torch_nvshmem APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch_nvshmem PROPERTIES
+  IMPORTED_LINK_DEPENDENT_LIBRARIES_RELEASE "torch_cpu"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch_nvshmem.so"
+  IMPORTED_SONAME_RELEASE "libtorch_nvshmem.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch_nvshmem )
+list(APPEND _cmake_import_check_files_for_torch_nvshmem "${_IMPORT_PREFIX}/lib/libtorch_nvshmem.so" )
+
+# Import target "torch_cpu" for configuration "Release"
+set_property(TARGET torch_cpu APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch_cpu PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch_cpu.so"
+  IMPORTED_SONAME_RELEASE "libtorch_cpu.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch_cpu )
+list(APPEND _cmake_import_check_files_for_torch_cpu "${_IMPORT_PREFIX}/lib/libtorch_cpu.so" )
+
+# Import target "torch_cuda" for configuration "Release"
+set_property(TARGET torch_cuda APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch_cuda PROPERTIES
+  IMPORTED_LINK_DEPENDENT_LIBRARIES_RELEASE "torch_nvshmem"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch_cuda.so"
+  IMPORTED_SONAME_RELEASE "libtorch_cuda.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch_cuda )
+list(APPEND _cmake_import_check_files_for_torch_cuda "${_IMPORT_PREFIX}/lib/libtorch_cuda.so" )
+
+# Import target "torch" for configuration "Release"
+set_property(TARGET torch APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch.so"
+  IMPORTED_SONAME_RELEASE "libtorch.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch )
+list(APPEND _cmake_import_check_files_for_torch "${_IMPORT_PREFIX}/lib/libtorch.so" )
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..086cc1e2547c8f2ba2536d918a6676f65f38f56a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake
@@ -0,0 +1,200 @@
+# Generated by CMake
+
+if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
+   message(FATAL_ERROR "CMake >= 3.0.0 required")
+endif()
+if(CMAKE_VERSION VERSION_LESS "3.0.0")
+   message(FATAL_ERROR "CMake >= 3.0.0 required")
+endif()
+cmake_policy(PUSH)
+cmake_policy(VERSION 3.0.0...4.0)
+#----------------------------------------------------------------
+# Generated CMake target import file.
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Protect against multiple inclusion, which would fail when already imported targets are added once more.
+set(_cmake_targets_defined "")
+set(_cmake_targets_not_defined "")
+set(_cmake_expected_targets "")
+foreach(_cmake_expected_target IN ITEMS headeronly c10_cuda c10 torch_nvshmem torch_cpu torch_cpu_library torch_cuda torch_cuda_library torch torch_library)
+  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
+  if(TARGET "${_cmake_expected_target}")
+    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
+  else()
+    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
+  endif()
+endforeach()
+unset(_cmake_expected_target)
+if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
+  unset(_cmake_targets_defined)
+  unset(_cmake_targets_not_defined)
+  unset(_cmake_expected_targets)
+  unset(CMAKE_IMPORT_FILE_VERSION)
+  cmake_policy(POP)
+  return()
+endif()
+if(NOT _cmake_targets_defined STREQUAL "")
+  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
+  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
+  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
+endif()
+unset(_cmake_targets_defined)
+unset(_cmake_targets_not_defined)
+unset(_cmake_expected_targets)
+
+
+# Compute the installation prefix relative to this file.
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+if(_IMPORT_PREFIX STREQUAL "/")
+  set(_IMPORT_PREFIX "")
+endif()
+
+# Create imported target headeronly
+add_library(headeronly INTERFACE IMPORTED)
+
+# Create imported target c10_cuda
+add_library(c10_cuda SHARED IMPORTED)
+
+set_target_properties(c10_cuda PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "c10;torch::cudart"
+)
+
+# Create imported target c10
+add_library(c10 SHARED IMPORTED)
+
+set_target_properties(c10 PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "headeronly"
+)
+
+# Create imported target torch_nvshmem
+add_library(torch_nvshmem SHARED IMPORTED)
+
+set_target_properties(torch_nvshmem PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "USE_NVSHMEM"
+)
+
+# Create imported target torch_cpu
+add_library(torch_cpu SHARED IMPORTED)
+
+set_target_properties(torch_cpu PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "USE_DISTRIBUTED;USE_C10D_GLOO;USE_RPC;USE_TENSORPIPE"
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "protobuf::libprotobuf;c10;caffe2::mkl"
+)
+
+# Create imported target torch_cpu_library
+add_library(torch_cpu_library INTERFACE IMPORTED)
+
+set_target_properties(torch_cpu_library PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "\$"
+  INTERFACE_COMPILE_OPTIONS "\$"
+  INTERFACE_INCLUDE_DIRECTORIES "\$"
+  INTERFACE_LINK_LIBRARIES "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed;\$"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$"
+)
+
+# Create imported target torch_cuda
+add_library(torch_cuda SHARED IMPORTED)
+
+set_target_properties(torch_cuda PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "USE_NVSHMEM;USE_C10D_NCCL"
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "torch::cudart;c10_cuda;torch_cpu_library"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "include"
+)
+
+# Create imported target torch_cuda_library
+add_library(torch_cuda_library INTERFACE IMPORTED)
+
+set_target_properties(torch_cuda_library PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "\$"
+  INTERFACE_COMPILE_OPTIONS "\$"
+  INTERFACE_INCLUDE_DIRECTORIES "\$"
+  INTERFACE_LINK_LIBRARIES "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed;\$"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$"
+)
+
+# Create imported target torch
+add_library(torch SHARED IMPORTED)
+
+set_target_properties(torch PROPERTIES
+  INTERFACE_LINK_LIBRARIES "torch_cpu_library;torch_cuda_library"
+)
+
+# Create imported target torch_library
+add_library(torch_library INTERFACE IMPORTED)
+
+set_target_properties(torch_library PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "\$"
+  INTERFACE_COMPILE_OPTIONS "\$"
+  INTERFACE_INCLUDE_DIRECTORIES "\$"
+  INTERFACE_LINK_LIBRARIES "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed;\$"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$"
+)
+
+# Load information for each installed configuration.
+file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets-*.cmake")
+foreach(_cmake_config_file IN LISTS _cmake_config_files)
+  include("${_cmake_config_file}")
+endforeach()
+unset(_cmake_config_file)
+unset(_cmake_config_files)
+
+# Cleanup temporary variables.
+set(_IMPORT_PREFIX)
+
+# Loop over all imported files and verify that they actually exist
+foreach(_cmake_target IN LISTS _cmake_import_check_targets)
+  if(CMAKE_VERSION VERSION_LESS "3.28"
+      OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
+      OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
+    foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
+      if(NOT EXISTS "${_cmake_file}")
+        message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
+   \"${_cmake_file}\"
+but this file does not exist.  Possible reasons include:
+* The file was deleted, renamed, or moved to another location.
+* An install or uninstall procedure did not complete successfully.
+* The installation package was faulty and contained
+   \"${CMAKE_CURRENT_LIST_FILE}\"
+but not all the files it references.
+")
+      endif()
+    endforeach()
+  endif()
+  unset(_cmake_file)
+  unset("_cmake_import_check_files_for_${_cmake_target}")
+endforeach()
+unset(_cmake_target)
+unset(_cmake_import_check_targets)
+
+# Make sure the targets which have been exported in some other
+# export set exist.
+unset(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets)
+foreach(_target "protobuf::libprotobuf" )
+  if(NOT TARGET "${_target}" )
+    set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets "${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets} ${_target}")
+  endif()
+endforeach()
+
+if(DEFINED ${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets)
+  if(CMAKE_FIND_PACKAGE_NAME)
+    set( ${CMAKE_FIND_PACKAGE_NAME}_FOUND FALSE)
+    set( ${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE "The following imported targets are referenced, but are missing: ${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets}")
+  else()
+    message(FATAL_ERROR "The following imported targets are referenced, but are missing: ${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets}")
+  endif()
+endif()
+unset(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets)
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
+cmake_policy(POP)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..ec9ae530aa6b2bdceb87f966e706fb5c2a36349a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake
@@ -0,0 +1,1081 @@
+
+# This module is back-ported from CMake 3.17 and above to work with CMake 3.10
+
+# Distributed under the OSI-approved BSD 3-Clause License.  See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+#[=======================================================================[.rst:
+FindCUDAToolkit
+---------------
+
+.. versionadded:: 3.17
+
+This script locates the NVIDIA CUDA toolkit and the associated libraries, but
+does not require the ``CUDA`` language be enabled for a given project. This
+module does not search for the NVIDIA CUDA Samples.
+
+.. versionadded:: 3.19
+  QNX support.
+
+Search Behavior
+^^^^^^^^^^^^^^^
+
+The CUDA Toolkit search behavior uses the following order:
+
+1. If the ``CUDA`` language has been enabled we will use the directory
+   containing the compiler as the first search location for ``nvcc``.
+
+2. If the ``CUDAToolkit_ROOT`` cmake configuration variable (e.g.,
+   ``-DCUDAToolkit_ROOT=/some/path``) *or* environment variable is defined, it
+   will be searched.  If both an environment variable **and** a
+   configuration variable are specified, the *configuration* variable takes
+   precedence.
+
+   The directory specified here must be such that the executable ``nvcc`` or
+   the appropriate ``version.txt`` file can be found underneath the specified
+   directory.
+
+3. If the CUDA_PATH environment variable is defined, it will be searched
+   for ``nvcc``.
+
+4. The user's path is searched for ``nvcc`` using :command:`find_program`.  If
+   this is found, no subsequent search attempts are performed.  Users are
+   responsible for ensuring that the first ``nvcc`` to show up in the path is
+   the desired path in the event that multiple CUDA Toolkits are installed.
+
+5. On Unix systems, if the symbolic link ``/usr/local/cuda`` exists, this is
+   used.  No subsequent search attempts are performed.  No default symbolic link
+   location exists for the Windows platform.
+
+6. The platform specific default install locations are searched.  If exactly one
+   candidate is found, this is used.  The default CUDA Toolkit install locations
+   searched are:
+
+   +-------------+-------------------------------------------------------------+
+   | Platform    | Search Pattern                                              |
+   +=============+=============================================================+
+   | macOS       | ``/Developer/NVIDIA/CUDA-X.Y``                              |
+   +-------------+-------------------------------------------------------------+
+   | Other Unix  | ``/usr/local/cuda-X.Y``                                     |
+   +-------------+-------------------------------------------------------------+
+   | Windows     | ``C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y`` |
+   +-------------+-------------------------------------------------------------+
+
+   Where ``X.Y`` would be a specific version of the CUDA Toolkit, such as
+   ``/usr/local/cuda-9.0`` or
+   ``C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0``
+
+   .. note::
+
+       When multiple CUDA Toolkits are installed in the default location of a
+       system (e.g., both ``/usr/local/cuda-9.0`` and ``/usr/local/cuda-10.0``
+       exist but the ``/usr/local/cuda`` symbolic link does **not** exist), this
+       package is marked as **not** found.
+
+       There are too many factors involved in making an automatic decision in
+       the presence of multiple CUDA Toolkits being installed.  In this
+       situation, users are encouraged to either (1) set ``CUDAToolkit_ROOT`` or
+       (2) ensure that the correct ``nvcc`` executable shows up in ``$PATH`` for
+       :command:`find_program` to find.
+
+Arguments
+^^^^^^^^^
+
+``[<version>]``
+    The ``[<version>]`` argument requests a version with which the package found
+    should be compatible. See :ref:`find_package version format <FIND_PACKAGE_VERSION_FORMAT>`
+    for more details.
+
+Options
+^^^^^^^
+
+``REQUIRED``
+    If specified, configuration will error if a suitable CUDA Toolkit is not
+    found.
+
+``QUIET``
+    If specified, the search for a suitable CUDA Toolkit will not produce any
+    messages.
+
+``EXACT``
+    If specified, the CUDA Toolkit is considered found only if the exact
+    ``VERSION`` specified is recovered.
+
+Imported targets
+^^^^^^^^^^^^^^^^
+
+An :ref:`imported target <Imported Targets>` named ``CUDA::toolkit`` is provided.
+
+This module defines :prop_tgt:`IMPORTED` targets for each
+of the following libraries that are part of the CUDAToolkit:
+
+- :ref:`CUDA Runtime Library<cuda_toolkit_rt_lib>`
+- :ref:`CUDA Driver Library<cuda_toolkit_driver_lib>`
+- :ref:`cuBLAS<cuda_toolkit_cuBLAS>`
+- :ref:`cuFFT<cuda_toolkit_cuFFT>`
+- :ref:`cuRAND`
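+- :ref:`cuSOLVER<cuda_toolkit_cuSOLVER>`
+- :ref:`cuSPARSE<cuda_toolkit_cuSPARSE>`
+- :ref:`cuPTI<cuda_toolkit_cupti>`
+- :ref:`NPP<cuda_toolkit_NPP>`
+- :ref:`nvBLAS<cuda_toolkit_nvBLAS>`
+- :ref:`nvGRAPH<cuda_toolkit_nvGRAPH>`
+- :ref:`nvJPEG<cuda_toolkit_nvJPEG>`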
+- :ref:`nvidia-ML`
+- :ref:`nvRTC`
+- :ref:`nvToolsExt`
+- :ref:`OpenCL`
+- :ref:`cuLIBOS`
+
+.. _`cuda_toolkit_rt_lib`:
+
+CUDA Runtime Library
+""""""""""""""""""""
+
+The CUDA Runtime library (cudart) is what most applications will typically
+need to link against to make any calls such as `cudaMalloc` and `cudaFree`.
+
+Targets Created:
+
+- ``CUDA::cudart``
+- ``CUDA::cudart_static``
+
+.. _`cuda_toolkit_driver_lib`:
+
+CUDA Driver Library
+""""""""""""""""""""
+
+The CUDA Driver library (cuda) is used by applications that use calls
+such as `cuMemAlloc` and `cuMemFree`.
+
+Targets Created:
+
+- ``CUDA::cuda_driver``
+
+.. _`cuda_toolkit_cuBLAS`:
+
+cuBLAS
+""""""
+
+The `cuBLAS `_ library.
+
+Targets Created:
+
+- ``CUDA::cublas``
+- ``CUDA::cublas_static``
+- ``CUDA::cublasLt`` starting in CUDA 10.1
+- ``CUDA::cublasLt_static`` starting in CUDA 10.1
+
+.. _`cuda_toolkit_cuFFT`:
+
+cuFFT
+"""""
+
+The `cuFFT `_ library.
+
+Targets Created:
+
+- ``CUDA::cufft``
+- ``CUDA::cufftw``
+- ``CUDA::cufft_static``
+- ``CUDA::cufft_static_nocallback`` starting in CUDA 9.2, requires CMake 3.23+
+- ``CUDA::cufftw_static``
+
+cuRAND
+""""""
+
+The `cuRAND `_ library.
+
+Targets Created:
+
+- ``CUDA::curand``
+- ``CUDA::curand_static``
+
+.. _`cuda_toolkit_cuSOLVER`:
+
+cuSOLVER
+""""""""
+
+The `cuSOLVER `_ library.
+
+Targets Created:
+
+- ``CUDA::cusolver``
+- ``CUDA::cusolver_static``
+
+.. _`cuda_toolkit_cuSPARSE`:
+
+cuSPARSE
+""""""""
+
+The `cuSPARSE `_ library.
+
+Targets Created:
+
+- ``CUDA::cusparse``
+- ``CUDA::cusparse_static``
+
+.. _`cuda_toolkit_cupti`:
+
+cupti
+"""""
+
+The `NVIDIA CUDA Profiling Tools Interface `_.
+
+Targets Created:
+
+- ``CUDA::cupti``
+- ``CUDA::cupti_static``
+
+.. _`cuda_toolkit_NPP`:
+
+NPP
+"""
+
+The `NPP `_ libraries.
+
+Targets Created:
+
+- `nppc`:
+
+  - ``CUDA::nppc``
+  - ``CUDA::nppc_static``
+
+- `nppial`: Arithmetic and logical operation functions in `nppi_arithmetic_and_logical_operations.h`
+
+  - ``CUDA::nppial``
+  - ``CUDA::nppial_static``
+
+- `nppicc`: Color conversion and sampling functions in `nppi_color_conversion.h`
+
+  - ``CUDA::nppicc``
+  - ``CUDA::nppicc_static``
+
+- `nppicom`: JPEG compression and decompression functions in `nppi_compression_functions.h`
+  Removed starting in CUDA 11.0, use :ref:`nvJPEG<cuda_toolkit_nvJPEG>` instead.
+
+  - ``CUDA::nppicom``
+  - ``CUDA::nppicom_static``
+
+- `nppidei`: Data exchange and initialization functions in `nppi_data_exchange_and_initialization.h`
+
+  - ``CUDA::nppidei``
+  - ``CUDA::nppidei_static``
+
+- `nppif`: Filtering and computer vision functions in `nppi_filter_functions.h`
+
+  - ``CUDA::nppif``
+  - ``CUDA::nppif_static``
+
+- `nppig`: Geometry transformation functions found in `nppi_geometry_transforms.h`
+
+  - ``CUDA::nppig``
+  - ``CUDA::nppig_static``
+
+- `nppim`: Morphological operation functions found in `nppi_morphological_operations.h`
+
+  - ``CUDA::nppim``
+  - ``CUDA::nppim_static``
+
+- `nppist`: Statistics and linear transform in `nppi_statistics_functions.h` and `nppi_linear_transforms.h`
+
+  - ``CUDA::nppist``
+  - ``CUDA::nppist_static``
+
+- `nppisu`: Memory support functions in `nppi_support_functions.h`
+
+  - ``CUDA::nppisu``
+  - ``CUDA::nppisu_static``
+
+- `nppitc`: Threshold and compare operation functions in `nppi_threshold_and_compare_operations.h`
+
+  - ``CUDA::nppitc``
+  - ``CUDA::nppitc_static``
+
+- `npps`:
+
+  - ``CUDA::npps``
+  - ``CUDA::npps_static``
+
+.. _`cuda_toolkit_nvBLAS`:
+
+nvBLAS
+""""""
+
+The `nvBLAS <https://docs.nvidia.com/cuda/nvblas/index.html>`_ libraries.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvblas``
+
+.. _`cuda_toolkit_nvGRAPH`:
+
+nvGRAPH
+"""""""
+
+The `nvGRAPH <https://docs.nvidia.com/cuda/nvgraph/index.html>`_ library.
+Removed starting in CUDA 11.0
+
+Targets Created:
+
+- ``CUDA::nvgraph``
+- ``CUDA::nvgraph_static``
+
+
+.. _`cuda_toolkit_nvJPEG`:
+
+nvJPEG
+""""""
+
+The `nvJPEG <https://docs.nvidia.com/cuda/nvjpeg/index.html>`_ library.
+Introduced in CUDA 10.
+
+Targets Created:
+
+- ``CUDA::nvjpeg``
+- ``CUDA::nvjpeg_static``
+
+.. _`cuda_toolkit_nvRTC`:
+
+nvRTC
+"""""
+
+The `nvRTC <https://docs.nvidia.com/cuda/nvrtc/index.html>`_ (Runtime Compilation) library.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvrtc``
+
+.. _`cuda_toolkit_nvml`:
+
+nvidia-ML
+"""""""""
+
+The `NVIDIA Management Library <https://developer.nvidia.com/nvidia-management-library-nvml>`_.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvml``
+
+.. _`cuda_toolkit_nvToolsExt`:
+
+nvToolsExt
+""""""""""
+
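+The `NVIDIA Tools Extension <https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm>`_.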
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvToolsExt``
+
+.. _`cuda_toolkit_opencl`:
+
+OpenCL
+""""""
+
+The `NVIDIA OpenCL Library <https://developer.nvidia.com/opencl>`_.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::OpenCL``
+
+.. _`cuda_toolkit_cuLIBOS`:
+
+cuLIBOS
+"""""""
+
+The cuLIBOS library is a backend thread abstraction layer library which is
+static only.  The ``CUDA::cublas_static``, ``CUDA::cusparse_static``,
+``CUDA::cufft_static``, ``CUDA::curand_static``, and (when implemented) NPP
+libraries all automatically have this dependency linked.
+
+Target Created:
+
+- ``CUDA::culibos``
+
+**Note**: direct usage of this target by consumers should not be necessary.
+
+.. _`cuda_toolkit_cuRAND`:
+
+
+
+Result variables
+^^^^^^^^^^^^^^^^
+
+``CUDAToolkit_FOUND``
+    A boolean specifying whether or not the CUDA Toolkit was found.
+
+``CUDAToolkit_VERSION``
+    The exact version of the CUDA Toolkit found (as reported by
+    ``nvcc --version`` or ``version.txt``).
+
+``CUDAToolkit_VERSION_MAJOR``
+    The major version of the CUDA Toolkit.
+
+``CUDAToolkit_VERSION_MINOR``
+    The minor version of the CUDA Toolkit.
+
+``CUDAToolkit_VERSION_PATCH``
+    The patch version of the CUDA Toolkit.
+
+``CUDAToolkit_BIN_DIR``
+    The path to the CUDA Toolkit binary directory that contains the CUDA
+    executable ``nvcc``.
+
+``CUDAToolkit_INCLUDE_DIRS``
+    The path to the CUDA Toolkit ``include`` folder containing the header files
+    required to compile a project linking against CUDA.
+
+``CUDAToolkit_LIBRARY_DIR``
+    The path to the CUDA Toolkit library directory that contains the CUDA
+    Runtime library ``cudart``.
+
+``CUDAToolkit_LIBRARY_ROOT``
+    .. versionadded:: 3.18
+
+    The path to the CUDA Toolkit directory containing the nvvm directory and
+    version.txt.
+
+``CUDAToolkit_TARGET_DIR``
+    The path to the CUDA Toolkit directory including the target architecture
+    when cross-compiling. When not cross-compiling this will be equivalent to
+    the parent directory of ``CUDAToolkit_BIN_DIR``.
+
+``CUDAToolkit_NVCC_EXECUTABLE``
+    The path to the NVIDIA CUDA compiler ``nvcc``.  Note that this path may
+    **not** be the same as
+    :variable:`CMAKE_CUDA_COMPILER <CMAKE_<LANG>_COMPILER>`.  ``nvcc`` must be
+    found to determine the CUDA Toolkit version as well as determining other
+    features of the Toolkit.  This variable is set for the convenience of
+    modules that depend on this one.
+
+
+#]=======================================================================]
+
+# NOTE: much of this was simply extracted from FindCUDA.cmake.
+
+#   James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#   Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#   Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#   Copyright (c) 2007-2009
+#   Scientific Computing and Imaging Institute, University of Utah
+#
+#   This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#   for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+###############################################################################
+
+# The toolkit is located during compiler detection for CUDA and stored in CMakeCUDACompiler.cmake as
+# CMAKE_CUDA_COMPILER_TOOLKIT_ROOT and CMAKE_CUDA_COMPILER_LIBRARY_ROOT.
+# We compute the rest based on those here to avoid re-searching and to avoid finding a possibly
+# different installation.
+if(CMAKE_CUDA_COMPILER_TOOLKIT_ROOT)
+  set(CUDAToolkit_ROOT_DIR "${CMAKE_CUDA_COMPILER_TOOLKIT_ROOT}")
+  set(CUDAToolkit_LIBRARY_ROOT "${CMAKE_CUDA_COMPILER_LIBRARY_ROOT}")
+  set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_TOOLKIT_VERSION}")
+
+  if(CUDAToolkit_VERSION MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+    set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+    set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+    set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+  endif()
+else()
+  # Try to determine the toolkit root directory from a set of search locations.
+  #
+  # Keyword arguments:
+  #   SEARCH_PATHS  directories handed to find_program()/find_file() as PATHS.
+  #   FIND_FLAGS    extra arguments appended verbatim to the find_program() call.
+  #
+  # Results:
+  #   CUDAToolkit_NVCC_EXECUTABLE / CUDAToolkit_SENTINEL_FILE are set by the
+  #   find_program()/find_file() calls below.
+  #   CUDAToolkit_BIN_DIR is written to the cache (FORCE) once derived.
+  #   CUDAToolkit_ROOT_DIR is set in the caller's scope to the parent directory
+  #   of CUDAToolkit_BIN_DIR; it is left untouched when no bin dir is known.
+  function(_CUDAToolkit_find_root_dir )
+    cmake_parse_arguments(arg "" "" "SEARCH_PATHS;FIND_FLAGS" ${ARGN})
+
+    # Only search when no bin directory is known yet.
+    if(NOT CUDAToolkit_BIN_DIR)
+      # Skip the nvcc lookup when a sentinel version.txt has already been found.
+      if(NOT CUDAToolkit_SENTINEL_FILE)
+        find_program(CUDAToolkit_NVCC_EXECUTABLE
+          NAMES nvcc nvcc.exe
+          PATHS ${arg_SEARCH_PATHS}
+          ${arg_FIND_FLAGS}
+        )
+      endif()
+
+      # Without nvcc, fall back to looking for version.txt in the search paths only.
+      if(NOT CUDAToolkit_NVCC_EXECUTABLE)
+        find_file(CUDAToolkit_SENTINEL_FILE
+          NAMES version.txt
+          PATHS ${arg_SEARCH_PATHS}
+          NO_DEFAULT_PATH
+        )
+      endif()
+
+      if(EXISTS "${CUDAToolkit_NVCC_EXECUTABLE}")
+        # If NVCC exists then invoke it to find the toolkit location.
+        # This allows us to support wrapper scripts (e.g. ccache or colornvcc), CUDA Toolkit,
+        # NVIDIA HPC SDK, and distro's splayed layouts
+        # stdout and stderr are captured into the same variable.
+        execute_process(COMMAND ${CUDAToolkit_NVCC_EXECUTABLE} "-v" "__cmake_determine_cuda"
+          OUTPUT_VARIABLE _CUDA_NVCC_OUT ERROR_VARIABLE _CUDA_NVCC_OUT)
+        # Prefer the "#$ TOP=<dir>" line of the verbose output; otherwise use
+        # the directory that contains the nvcc executable itself.
+        if(_CUDA_NVCC_OUT MATCHES "\\#\\$ TOP=([^\r\n]*)")
+          get_filename_component(CUDAToolkit_BIN_DIR "${CMAKE_MATCH_1}/bin" ABSOLUTE)
+        else()
+          get_filename_component(CUDAToolkit_BIN_DIR "${CUDAToolkit_NVCC_EXECUTABLE}" DIRECTORY)
+        endif()
+        unset(_CUDA_NVCC_OUT)
+
+        mark_as_advanced(CUDAToolkit_BIN_DIR)
+        set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "" FORCE)
+      endif()
+
+      # When a sentinel file is known, the bin directory is taken to be the
+      # "bin" subdirectory next to it (this overrides any nvcc-derived value).
+      if(CUDAToolkit_SENTINEL_FILE)
+        get_filename_component(CUDAToolkit_BIN_DIR ${CUDAToolkit_SENTINEL_FILE} DIRECTORY ABSOLUTE)
+        set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}/bin")
+
+        set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "" FORCE)
+        mark_as_advanced(CUDAToolkit_BIN_DIR)
+      endif()
+    endif()
+
+    # The root directory is the parent of the bin directory; export it to the caller.
+    if(CUDAToolkit_BIN_DIR)
+      get_filename_component(CUDAToolkit_ROOT_DIR ${CUDAToolkit_BIN_DIR} DIRECTORY ABSOLUTE)
+      set(CUDAToolkit_ROOT_DIR "${CUDAToolkit_ROOT_DIR}" PARENT_SCOPE)
+    endif()
+
+  endfunction()
+
+  # For NVCC we can easily deduce the SDK binary directory from the compiler path.
+  if(CMAKE_CUDA_COMPILER_LOADED AND NOT CUDAToolkit_BIN_DIR AND CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA")
+    get_filename_component(CUDAToolkit_BIN_DIR "${CMAKE_CUDA_COMPILER}" DIRECTORY)
+    set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "")
+    # Try language provided path first.
+    _CUDAToolkit_find_root_dir(SEARCH_PATHS "${CUDAToolkit_BIN_DIR}" FIND_FLAGS NO_DEFAULT_PATH)
+    mark_as_advanced(CUDAToolkit_BIN_DIR)
+  endif()
+
+  # Try user provided path
+  if(NOT CUDAToolkit_ROOT_DIR AND CUDAToolkit_ROOT)
+    _CUDAToolkit_find_root_dir(SEARCH_PATHS "${CUDAToolkit_ROOT}" FIND_FLAGS PATH_SUFFIXES bin NO_DEFAULT_PATH)
+  endif()
+  if(NOT CUDAToolkit_ROOT_DIR)
+    _CUDAToolkit_find_root_dir(FIND_FLAGS PATHS ENV CUDA_PATH PATH_SUFFIXES bin)
+  endif()
+
+  # If the user specified CUDAToolkit_ROOT but the toolkit could not be found, this is an error.
+  if(NOT CUDAToolkit_ROOT_DIR AND (DEFINED CUDAToolkit_ROOT OR DEFINED ENV{CUDAToolkit_ROOT}))
+    # Declare error messages now, print later depending on find_package args.
+    set(fail_base "Could not find nvcc executable in path specified by")
+    set(cuda_root_fail "${fail_base} CUDAToolkit_ROOT=${CUDAToolkit_ROOT}")
+    set(env_cuda_root_fail "${fail_base} environment variable CUDAToolkit_ROOT=$ENV{CUDAToolkit_ROOT}")
+
+    if(CUDAToolkit_FIND_REQUIRED)
+      if(DEFINED CUDAToolkit_ROOT)
+        message(FATAL_ERROR ${cuda_root_fail})
+      elseif(DEFINED ENV{CUDAToolkit_ROOT})
+        message(FATAL_ERROR ${env_cuda_root_fail})
+      endif()
+    else()
+      if(NOT CUDAToolkit_FIND_QUIETLY)
+        if(DEFINED CUDAToolkit_ROOT)
+          message(STATUS ${cuda_root_fail})
+        elseif(DEFINED ENV{CUDAToolkit_ROOT})
+          message(STATUS ${env_cuda_root_fail})
+        endif()
+      endif()
+      set(CUDAToolkit_FOUND FALSE)
+      unset(fail_base)
+      unset(cuda_root_fail)
+      unset(env_cuda_root_fail)
+      return()
+    endif()
+  endif()
+
+  # CUDAToolkit_ROOT cmake / env variable not specified, try platform defaults.
+  #
+  # - Linux: /usr/local/cuda-X.Y
+  # - macOS: /Developer/NVIDIA/CUDA-X.Y
+  # - Windows: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y
+  #
+  # We will also search the default symlink location /usr/local/cuda first since
+  # if CUDAToolkit_ROOT is not specified, it is assumed that the symlinked
+  # directory is the desired location.
+  if(NOT CUDAToolkit_ROOT_DIR)
+    if(UNIX)
+      if(NOT APPLE)
+        set(platform_base "/usr/local/cuda-")
+      else()
+        set(platform_base "/Developer/NVIDIA/CUDA-")
+      endif()
+    else()
+      set(platform_base "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v")
+    endif()
+
+    # Glob for candidate CUDA installations under the platform base path.
+    file(GLOB possible_paths "${platform_base}*")
+    # Iterate the glob results and create a descending list.
+    set(versions)
+    foreach(p ${possible_paths})
+      # Extract version number from end of string
+      string(REGEX MATCH "[0-9][0-9]?\\.[0-9]$" p_version ${p})
+      if(IS_DIRECTORY ${p} AND p_version)
+        list(APPEND versions ${p_version})
+      endif()
+    endforeach()
+
+    # Sort numerically in descending order, so we try the newest versions first.
+    if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
+      list(SORT versions COMPARE NATURAL ORDER DESCENDING)
+    elseif(versions)
+      # Alphabetical sort here is not ideal but better than nothing
+      list(SORT versions)
+      list(REVERSE versions)
+    endif()
+
+    # With a descending list of versions, populate possible paths to search.
+    set(search_paths)
+    foreach(v ${versions})
+      list(APPEND search_paths "${platform_base}${v}")
+    endforeach()
+
+    # Force the global default /usr/local/cuda to the front on Unix.
+    if(UNIX)
+      list(INSERT search_paths 0 "/usr/local/cuda")
+    endif()
+
+    # Now search for the toolkit again using the platform default search paths.
+    _CUDAToolkit_find_root_dir(SEARCH_PATHS "${search_paths}" FIND_FLAGS PATH_SUFFIXES bin)
+
+    # We are done with these variables now, cleanup for caller.
+    unset(platform_base)
+    unset(possible_paths)
+    unset(versions)
+    unset(search_paths)
+
+    if(NOT CUDAToolkit_ROOT_DIR)
+      if(CUDAToolkit_FIND_REQUIRED)
+        message(FATAL_ERROR "Could not find nvcc, please set CUDAToolkit_ROOT.")
+      elseif(NOT CUDAToolkit_FIND_QUIETLY)
+        message(STATUS "Could not find nvcc, please set CUDAToolkit_ROOT.")
+      endif()
+
+      set(CUDAToolkit_FOUND FALSE)
+      return()
+    endif()
+  endif()
+endif()
+
+if(NOT CUDAToolkit_BIN_DIR)
+  set(CUDAToolkit_BIN_DIR "${CUDAToolkit_ROOT_DIR}/bin")
+endif()
+
+if(NOT CUDAToolkit_NVCC_EXECUTABLE)
+  set(CUDAToolkit_NVCC_EXECUTABLE "${CUDAToolkit_BIN_DIR}/nvcc${CMAKE_EXECUTABLE_SUFFIX}")
+endif()
+
+if(CMAKE_CUDA_COMPILER_TOOLKIT_VERSION)
+  set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_TOOLKIT_VERSION}")
+else()
+  # Store the path of the first existing version.txt in <result_variable>
+  # (in the caller's scope). The variable is not touched when none exists.
+  #
+  # Probe order: CUDAToolkit_ROOT, CUDAToolkit_ROOT_DIR, then the
+  # usr/lib/cuda directory below CMAKE_SYSROOT_LINK and below CMAKE_SYSROOT.
+  function(_CUDAToolkit_find_version_file result_variable)
+    # Non-scattered installations are preferred over scattered ones, so the
+    # explicit toolkit roots are checked before the sysroot locations.
+    if(CUDAToolkit_ROOT AND EXISTS "${CUDAToolkit_ROOT}/version.txt")
+      set(${result_variable} "${CUDAToolkit_ROOT}/version.txt" PARENT_SCOPE)
+      return()
+    endif()
+
+    if(CUDAToolkit_ROOT_DIR AND EXISTS "${CUDAToolkit_ROOT_DIR}/version.txt")
+      set(${result_variable} "${CUDAToolkit_ROOT_DIR}/version.txt" PARENT_SCOPE)
+      return()
+    endif()
+
+    if(CMAKE_SYSROOT_LINK AND EXISTS "${CMAKE_SYSROOT_LINK}/usr/lib/cuda/version.txt")
+      set(${result_variable} "${CMAKE_SYSROOT_LINK}/usr/lib/cuda/version.txt" PARENT_SCOPE)
+      return()
+    endif()
+
+    if(EXISTS "${CMAKE_SYSROOT}/usr/lib/cuda/version.txt")
+      set(${result_variable} "${CMAKE_SYSROOT}/usr/lib/cuda/version.txt" PARENT_SCOPE)
+    endif()
+  endfunction()
+
+  _CUDAToolkit_find_version_file( _CUDAToolkit_version_file )
+  if(_CUDAToolkit_version_file)
+    # CUDAToolkit_LIBRARY_ROOT contains the device library and version file.
+    get_filename_component(CUDAToolkit_LIBRARY_ROOT "${_CUDAToolkit_version_file}" DIRECTORY ABSOLUTE)
+  endif()
+  unset(_CUDAToolkit_version_file)
+
+  if(CUDAToolkit_NVCC_EXECUTABLE AND
+     CMAKE_CUDA_COMPILER_VERSION AND
+     CUDAToolkit_NVCC_EXECUTABLE STREQUAL CMAKE_CUDA_COMPILER)
+    # Need to set these based off the already computed CMAKE_CUDA_COMPILER_VERSION value
+    # This if statement will always match, but is used to provide variables for MATCH 1,2,3...
+    if(CMAKE_CUDA_COMPILER_VERSION MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+      set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+      set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+      set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+      set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_VERSION}")
+    endif()
+  elseif(CUDAToolkit_NVCC_EXECUTABLE)
+    # Compute the version by invoking nvcc
+    execute_process(COMMAND ${CUDAToolkit_NVCC_EXECUTABLE} "--version" OUTPUT_VARIABLE NVCC_OUT)
+    if(NVCC_OUT MATCHES [=[ V([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+      set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+      set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+      set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+      set(CUDAToolkit_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}")
+    endif()
+    unset(NVCC_OUT)
+  else()
+    _CUDAToolkit_find_version_file(version_file)
+    if(version_file)
+      file(READ "${version_file}" VERSION_INFO)
+      if(VERSION_INFO MATCHES [=[CUDA Version ([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+        set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+        set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+        set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+        set(CUDAToolkit_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}")
+      endif()
+    endif()
+  endif()
+endif()
+
+# Find target directory when crosscompiling.
+# Maps CMAKE_SYSTEM_PROCESSOR (plus ANDROID_ARCH_NAME / CMAKE_SYSTEM_NAME for
+# aarch64) to a CUDAToolkit_TARGET_NAME, then uses
+# <root>/targets/<target name> as CUDAToolkit_TARGET_DIR if that directory exists.
+if(CMAKE_CROSSCOMPILING)
+  # The exact "armv7-a" comparison must come first: the generic "arm" regex
+  # in the next branch would match it as well.
+  if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7-a")
+    # Support for NVPACK
+    set(CUDAToolkit_TARGET_NAME "armv7-linux-androideabi")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm")
+    set(CUDAToolkit_TARGET_NAME "armv7-linux-gnueabihf")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+    if(ANDROID_ARCH_NAME STREQUAL "arm64")
+      set(CUDAToolkit_TARGET_NAME "aarch64-linux-androideabi")
+    elseif(CMAKE_SYSTEM_NAME STREQUAL "QNX")
+      set(CUDAToolkit_TARGET_NAME "aarch64-qnx")
+    else()
+      set(CUDAToolkit_TARGET_NAME "aarch64-linux")
+    endif(ANDROID_ARCH_NAME STREQUAL "arm64")
+  elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
+    set(CUDAToolkit_TARGET_NAME "x86_64-linux")
+  endif()
+
+  # For any other processor CUDAToolkit_TARGET_NAME is not set by this block.
+  if(EXISTS "${CUDAToolkit_ROOT_DIR}/targets/${CUDAToolkit_TARGET_NAME}")
+    set(CUDAToolkit_TARGET_DIR "${CUDAToolkit_ROOT_DIR}/targets/${CUDAToolkit_TARGET_NAME}")
+    # add known CUDA target root path to the set of directories we search for programs, libraries and headers
+    list(PREPEND CMAKE_FIND_ROOT_PATH "${CUDAToolkit_TARGET_DIR}")
+
+    # Mark that we need to pop the root search path changes after we have
+    # found all cuda libraries so that searches for our cross-compilation
+    # libraries work when another cuda sdk is in CMAKE_PREFIX_PATH or
+    # PATH. The entry is removed again at the end of this module.
+    set(_CUDAToolkit_Pop_ROOT_PATH True)
+  endif()
+endif()
+
+# If not already set we can simply use the toolkit root or it's a scattered installation.
+if(NOT CUDAToolkit_TARGET_DIR)
+  # Not cross compiling
+  set(CUDAToolkit_TARGET_DIR "${CUDAToolkit_ROOT_DIR}")
+  # Now that we have the real ROOT_DIR, find components inside it.
+  list(APPEND CMAKE_PREFIX_PATH ${CUDAToolkit_ROOT_DIR})
+
+  # Mark that we need to pop the prefix path changes after we have
+  # found the cudart library.
+  set(_CUDAToolkit_Pop_Prefix True)
+endif()
+
+# CUDAToolkit_TARGET_DIR always points to the directory containing the include directory.
+# On a scattered installation /usr, on a non-scattered something like /usr/local/cuda or /usr/local/cuda-10.2/targets/aarch64-linux.
+if(EXISTS "${CUDAToolkit_TARGET_DIR}/include/cuda_runtime.h")
+  set(CUDAToolkit_INCLUDE_DIR "${CUDAToolkit_TARGET_DIR}/include")
+elseif(NOT CUDAToolkit_FIND_QUIETLY)
+  message(STATUS "Unable to find cuda_runtime.h in \"${CUDAToolkit_TARGET_DIR}/include\" for CUDAToolkit_INCLUDE_DIR.")
+endif()
+
+# The NVHPC layout moves math library headers and libraries to a sibling directory.
+# Create a separate variable so this directory can be selectively added to math targets.
+if(NOT EXISTS "${CUDAToolkit_INCLUDE_DIR}/cublas_v2.h")
+  set(CUDAToolkit_MATH_INCLUDE_DIR "${CUDAToolkit_TARGET_DIR}/../../math_libs/include")
+  get_filename_component(CUDAToolkit_MATH_INCLUDE_DIR "${CUDAToolkit_MATH_INCLUDE_DIR}" ABSOLUTE)
+  if(NOT EXISTS "${CUDAToolkit_MATH_INCLUDE_DIR}/cublas_v2.h")
+    if(NOT CUDAToolkit_FIND_QUIETLY)
+      message(STATUS "Unable to find cublas_v2.h in either \"${CUDAToolkit_INCLUDE_DIR}\" or \"${CUDAToolkit_MATH_INCLUDE_DIR}\"")
+    endif()
+    unset(CUDAToolkit_MATH_INCLUDE_DIR)
+  endif()
+endif()
+
+# Find the CUDA Runtime Library libcudart
+find_library(CUDA_CUDART
+  NAMES cudart
+  PATH_SUFFIXES lib64 lib/x64
+)
+find_library(CUDA_CUDART
+  NAMES cudart
+  PATH_SUFFIXES lib64/stubs lib/x64/stubs
+)
+
+if(NOT CUDA_CUDART AND NOT CUDAToolkit_FIND_QUIETLY)
+  message(STATUS "Unable to find cudart library.")
+endif()
+
+if(_CUDAToolkit_Pop_Prefix)
+  list(REMOVE_AT CMAKE_PREFIX_PATH -1)
+  unset(_CUDAToolkit_Pop_Prefix)
+endif()
+
+#-----------------------------------------------------------------------------
+# Perform version comparison and validate all required variables are set.
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(CUDAToolkit
+  REQUIRED_VARS
+    CUDAToolkit_INCLUDE_DIR
+    CUDAToolkit_VERSION
+    CUDA_CUDART
+    CUDAToolkit_BIN_DIR
+  VERSION_VAR
+    CUDAToolkit_VERSION
+)
+
+mark_as_advanced(CUDA_CUDART
+                 CUDAToolkit_INCLUDE_DIR
+                 CUDAToolkit_NVCC_EXECUTABLE
+                 CUDAToolkit_SENTINEL_FILE
+                 )
+
+#-----------------------------------------------------------------------------
+# Construct result variables
+if(CUDAToolkit_FOUND)
+  set(CUDAToolkit_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIR})
+  get_filename_component(CUDAToolkit_LIBRARY_DIR ${CUDA_CUDART} DIRECTORY ABSOLUTE)
+endif()
+
+#-----------------------------------------------------------------------------
+# Construct import targets
+if(CUDAToolkit_FOUND)
+
+  # Find the CUDA library <lib_name> and, if found, create the imported target
+  # CUDA::<lib_name> for it.
+  #
+  # Keyword arguments:
+  #   ALT                  alternative library file names to search for.
+  #   DEPS                 names (without the CUDA:: prefix) of targets to add
+  #                        to INTERFACE_LINK_LIBRARIES; names whose target does
+  #                        not exist are skipped.
+  #   EXTRA_HINTS          additional HINTS for find_library().
+  #   EXTRA_PATH_SUFFIXES  additional PATH_SUFFIXES for the first (non-stub) search.
+  #   EXTRA_INCLUDE_DIRS   additional include directories for the target.
+  #
+  # The search result is cached in CUDA_<lib_name>_LIBRARY. Nothing is added or
+  # modified when CUDA::<lib_name> already exists or the library is not found,
+  # so DEPS passed for an existing target are ignored.
+  function(_CUDAToolkit_find_and_add_import_lib lib_name)
+    cmake_parse_arguments(arg "" "" "ALT;DEPS;EXTRA_HINTS;EXTRA_PATH_SUFFIXES;EXTRA_INCLUDE_DIRS" ${ARGN})
+
+    set(search_names ${lib_name} ${arg_ALT})
+
+    find_library(CUDA_${lib_name}_LIBRARY
+      NAMES ${search_names}
+      HINTS ${CUDAToolkit_LIBRARY_DIR}
+            ENV CUDA_PATH
+            ${arg_EXTRA_HINTS}
+      PATH_SUFFIXES nvidia/current lib64 lib/x64 lib
+                    ${arg_EXTRA_PATH_SUFFIXES}
+    )
+    # Don't try any stub directories until we have exhausted all other
+    # search locations.
+    find_library(CUDA_${lib_name}_LIBRARY
+      NAMES ${search_names}
+      HINTS ${CUDAToolkit_LIBRARY_DIR}
+            ENV CUDA_PATH
+            ${arg_EXTRA_HINTS}
+      PATH_SUFFIXES lib64/stubs lib/x64/stubs lib/stubs stubs
+                    # Support NVHPC splayed math library layout
+                    ../../math_libs/${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}/lib64
+                    ../../math_libs/lib64
+    )
+
+    mark_as_advanced(CUDA_${lib_name}_LIBRARY)
+
+    if(NOT TARGET CUDA::${lib_name} AND CUDA_${lib_name}_LIBRARY)
+      add_library(CUDA::${lib_name} UNKNOWN IMPORTED)
+      set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+          INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+      set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+          INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+      if(DEFINED CUDAToolkit_MATH_INCLUDE_DIR)
+        # Libraries found below a "math_libs" directory (NVHPC layout) also
+        # need the separate math include directory. The variable is named
+        # CUDAToolkit_MATH_INCLUDE_DIR (singular); the previously used
+        # "..._DIRS" spelling is never set and expanded to nothing.
+        string(FIND "${CUDA_${lib_name}_LIBRARY}" "math_libs" math_libs)
+        if(NOT ${math_libs} EQUAL -1)
+          set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+              INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_MATH_INCLUDE_DIR}")
+          set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+              INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_MATH_INCLUDE_DIR}")
+        endif()
+      endif()
+      set_property(TARGET CUDA::${lib_name} PROPERTY IMPORTED_LOCATION "${CUDA_${lib_name}_LIBRARY}")
+      # Only link dependencies whose targets were actually created.
+      foreach(dep ${arg_DEPS})
+        if(TARGET CUDA::${dep})
+          set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+              INTERFACE_LINK_LIBRARIES CUDA::${dep})
+        endif()
+      endforeach()
+      if(arg_EXTRA_INCLUDE_DIRS)
+        set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+            INTERFACE_INCLUDE_DIRECTORIES "${arg_EXTRA_INCLUDE_DIRS}")
+        set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+            INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${arg_EXTRA_INCLUDE_DIRS}")
+      endif()
+    endif()
+  endfunction()
+
+  if(NOT TARGET CUDA::toolkit)
+    add_library(CUDA::toolkit IMPORTED INTERFACE)
+    set_property(TARGET CUDA::toolkit APPEND PROPERTY
+        INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+    set_property(TARGET CUDA::toolkit APPEND PROPERTY
+        INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(cuda_driver ALT cuda)
+
+  _CUDAToolkit_find_and_add_import_lib(cudart)
+  _CUDAToolkit_find_and_add_import_lib(cudart_static)
+
+  # setup dependencies that are required for cudart_static when building
+  # on linux. These are generally only required when using the CUDA toolkit
+  # when CUDA language is disabled
+  if(NOT TARGET CUDA::cudart_static_deps
+     AND TARGET CUDA::cudart_static)
+
+    add_library(CUDA::cudart_static_deps IMPORTED INTERFACE)
+    set_property(TARGET CUDA::cudart_static APPEND PROPERTY
+        INTERFACE_LINK_LIBRARIES CUDA::cudart_static_deps)
+
+    if(UNIX AND (CMAKE_C_COMPILER OR CMAKE_CXX_COMPILER))
+      find_package(Threads REQUIRED)
+      set_property(TARGET CUDA::cudart_static_deps APPEND PROPERTY
+          INTERFACE_LINK_LIBRARIES Threads::Threads ${CMAKE_DL_LIBS})
+    endif()
+
+    if(UNIX AND NOT APPLE AND NOT (CMAKE_SYSTEM_NAME STREQUAL "QNX"))
+      # On Linux, you must link against librt when using the static cuda runtime.
+      find_library(CUDAToolkit_rt_LIBRARY rt)
+      mark_as_advanced(CUDAToolkit_rt_LIBRARY)
+      if(NOT CUDAToolkit_rt_LIBRARY)
+        message(WARNING "Could not find librt library, needed by CUDA::cudart_static")
+      else()
+        set_property(TARGET CUDA::cudart_static_deps APPEND PROPERTY
+            INTERFACE_LINK_LIBRARIES ${CUDAToolkit_rt_LIBRARY})
+      endif()
+    endif()
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(culibos) # it's a static library
+  foreach(cuda_lib cublasLt cufft curand cusparse nppc nvjpeg)
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib})
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib}_static DEPS culibos)
+  endforeach()
+
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.0.0)
+    # cublas depends on cublasLt
+    # https://docs.nvidia.com/cuda/archive/11.0/cublas/index.html#static-library
+    _CUDAToolkit_find_and_add_import_lib(cublas DEPS cublasLt)
+    _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS cublasLt_static)
+  else()
+    _CUDAToolkit_find_and_add_import_lib(cublas)
+    _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
+  endif()
+
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4)
+    _CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos)
+    _CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos)
+
+    _CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos)
+    _CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos)
+  endif()
+
+  # cuFFTW depends on cuFFT
+  _CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
+  _CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 9.2)
+    _CUDAToolkit_find_and_add_import_lib(cufft_static_nocallback DEPS culibos)
+  endif()
+
+  # cuSOLVER depends on cuBLAS, and cuSPARSE.
+  #
+  # All dependencies are collected first and each cusolver target is created by
+  # a single call: _CUDAToolkit_find_and_add_import_lib() does nothing for a
+  # target that already exists, so DEPS passed in a second call for the same
+  # library would be silently dropped.
+  set(_CUDAToolkit_cusolver_deps cublas cusparse)
+  set(_CUDAToolkit_cusolver_static_deps cublas_static cusparse_static culibos)
+
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 10.1.2)
+    # cusolver depends on liblapack_static.a starting with CUDA 10.1 update 2,
+    # https://docs.nvidia.com/cuda/archive/11.5.0/cusolver/index.html#static-link-lapack
+    _CUDAToolkit_find_and_add_import_lib(cusolver_lapack_static ALT lapack_static) # implementation detail static lib
+    list(APPEND _CUDAToolkit_cusolver_static_deps cusolver_lapack_static)
+  endif()
+
+  if(CUDAToolkit_VERSION VERSION_GREATER 11.2.1)
+    # cusolver depends on libcusolver_metis and cublasLt
+    # https://docs.nvidia.com/cuda/archive/11.2.2/cusolver/index.html#link-dependency
+    list(APPEND _CUDAToolkit_cusolver_deps cublasLt)
+
+    _CUDAToolkit_find_and_add_import_lib(cusolver_metis_static ALT metis_static) # implementation detail static lib
+    list(APPEND _CUDAToolkit_cusolver_static_deps cusolver_metis_static cublasLt_static)
+  endif()
+
+  # The helper libraries above must be located before these calls, because
+  # only dependencies whose targets already exist are linked.
+  _CUDAToolkit_find_and_add_import_lib(cusolver DEPS ${_CUDAToolkit_cusolver_deps})
+  _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS ${_CUDAToolkit_cusolver_static_deps})
+
+  unset(_CUDAToolkit_cusolver_deps)
+  unset(_CUDAToolkit_cusolver_static_deps)
+
+  # nvGRAPH depends on cuRAND, and cuSOLVER.
+  _CUDAToolkit_find_and_add_import_lib(nvgraph DEPS curand cusolver)
+  _CUDAToolkit_find_and_add_import_lib(nvgraph_static DEPS curand_static cusolver_static)
+
+  # Process the majority of the NPP libraries.
+  foreach(cuda_lib nppial nppicc nppidei nppif nppig nppim nppist nppitc npps nppicom nppisu)
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib} DEPS nppc)
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib}_static DEPS nppc_static)
+  endforeach()
+
+  find_path(CUDAToolkit_CUPTI_INCLUDE_DIR cupti.h PATHS
+      "${CUDAToolkit_ROOT_DIR}/extras/CUPTI/include"
+      "${CUDAToolkit_INCLUDE_DIR}/../extras/CUPTI/include"
+      "${CUDAToolkit_INCLUDE_DIR}"
+      NO_DEFAULT_PATH)
+  mark_as_advanced(CUDAToolkit_CUPTI_INCLUDE_DIR)
+
+  if(CUDAToolkit_CUPTI_INCLUDE_DIR)
+    _CUDAToolkit_find_and_add_import_lib(cupti
+                                        EXTRA_PATH_SUFFIXES ../extras/CUPTI/lib64/
+                                                            ../extras/CUPTI/lib/
+                                        EXTRA_INCLUDE_DIRS "${CUDAToolkit_CUPTI_INCLUDE_DIR}")
+    _CUDAToolkit_find_and_add_import_lib(cupti_static
+                                        EXTRA_PATH_SUFFIXES ../extras/CUPTI/lib64/
+                                                            ../extras/CUPTI/lib/
+                                        EXTRA_INCLUDE_DIRS "${CUDAToolkit_CUPTI_INCLUDE_DIR}")
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(nvrtc DEPS cuda_driver)
+
+  _CUDAToolkit_find_and_add_import_lib(nvml ALT nvidia-ml nvml)
+
+  # nvtools can be installed outside the CUDA toolkit directory,
+  # so search the NVTOOLSEXT_PATH windows only environment variable
+  set(nvToolsExt_EXTRA_PATH)
+  if(WIN32)
+     set(nvToolsExt_EXTRA_PATH "C:\\Program Files\\NVIDIA Corporation\\NvToolsExt")
+  endif()
+
+  find_path(CUDAToolkit_nvToolsExt_INCLUDE_DIR nvToolsExt.h
+      PATHS "${CUDAToolkit_INCLUDE_DIR}"
+            "${CUDAToolkit_ROOT_DIR}"
+            ENV NVTOOLSEXT_PATH
+            "${nvToolsExt_EXTRA_PATH}"
+      PATH_SUFFIXES include
+      NO_DEFAULT_PATH)
+  mark_as_advanced(CUDAToolkit_nvToolsExt_INCLUDE_DIR)
+
+  if(CUDAToolkit_nvToolsExt_INCLUDE_DIR)
+    _CUDAToolkit_find_and_add_import_lib(nvToolsExt
+        ALT nvToolsExt64 nvToolsExt64_1
+        EXTRA_HINTS ENV NVTOOLSEXT_PATH
+                    "${nvToolsExt_EXTRA_PATH}"
+        EXTRA_INCLUDE_DIRS "${CUDAToolkit_nvToolsExt_INCLUDE_DIR}")
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(OpenCL)
+endif()
+
+unset(CUDAToolkit_ROOT_DIR)
+
+if(_CUDAToolkit_Pop_ROOT_PATH)
+  list(REMOVE_AT CMAKE_FIND_ROOT_PATH 0)
+  unset(_CUDAToolkit_Pop_ROOT_PATH)
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..b614e1c492b99f7b3adf456b0b88bdf5cd26fd0b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake
@@ -0,0 +1,67 @@
+# Find the CUDSS library
+#
+# The following variables are optionally searched for defaults
+#  CUDSS_ROOT: Base directory where CUDSS is found (env CUDSS_ROOT_DIR is the deprecated spelling)
+#  CUDSS_INCLUDE_DIR: Directory where CUDSS header is searched for
+#  CUDSS_LIBRARY: Directory where CUDSS library is searched for
+#
+# The following are set after configuration is done:
+#  CUDSS_FOUND
+#  CUDSS_INCLUDE_PATH
+#  CUDSS_LIBRARY_PATH (plus CUDSS_VERSION when the library is found)
+
+include(FindPackageHandleStandardArgs)
+
+set(CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} CACHE PATH "Folder containing NVIDIA CUDSS")
+if(DEFINED ENV{CUDSS_ROOT_DIR})  # ENV{...} tests the env var itself; $ENV{...} would expand to its value first
+  message(WARNING "CUDSS_ROOT_DIR is deprecated. Please set CUDSS_ROOT instead.")
+endif()
+list(APPEND CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+
+# Compatibility layer for CMake <3.12; CMake >=3.12 already takes CUDSS_ROOT into account when searching for paths and libraries.
+list(APPEND CMAKE_PREFIX_PATH ${CUDSS_ROOT})
+
+set(CUDSS_INCLUDE_DIR $ENV{CUDSS_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA CUDSS header files")
+
+find_path(CUDSS_INCLUDE_PATH cudss.h
+  HINTS ${CUDSS_INCLUDE_DIR}
+  PATH_SUFFIXES cuda/include cuda include)
+
+set(CUDSS_LIBRARY $ENV{CUDSS_LIBRARY} CACHE PATH "Path to the CUDSS library file (e.g., libcudss.so)")
+
+set(CUDSS_LIBRARY_NAME "libcudss.so")
+if(MSVC)
+  set(CUDSS_LIBRARY_NAME "cudss.lib")
+endif()
+
+find_library(CUDSS_LIBRARY_PATH ${CUDSS_LIBRARY_NAME}
+  PATHS ${CUDSS_LIBRARY}
+  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+
+find_package_handle_standard_args(CUDSS DEFAULT_MSG CUDSS_LIBRARY_PATH CUDSS_INCLUDE_PATH)
+
+if(CUDSS_FOUND)
+  # Get CUDSS version: each component is first matched in the header, then reduced to its digits.
+  file(READ ${CUDSS_INCLUDE_PATH}/cudss.h CUDSS_HEADER_CONTENTS)
+  string(REGEX MATCH "define CUDSS_VER_MAJOR * +([0-9]+)"
+               CUDSS_VERSION_MAJOR "${CUDSS_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDSS_VER_MAJOR * +([0-9]+)" "\\1"
+               CUDSS_VERSION_MAJOR "${CUDSS_VERSION_MAJOR}")
+  string(REGEX MATCH "define CUDSS_VER_MINOR * +([0-9]+)"
+               CUDSS_VERSION_MINOR "${CUDSS_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDSS_VER_MINOR * +([0-9]+)" "\\1"
+               CUDSS_VERSION_MINOR "${CUDSS_VERSION_MINOR}")
+  string(REGEX MATCH "define CUDSS_VER_PATCH * +([0-9]+)"
+               CUDSS_VERSION_PATCH "${CUDSS_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDSS_VER_PATCH * +([0-9]+)" "\\1"
+               CUDSS_VERSION_PATCH "${CUDSS_VERSION_PATCH}")
+  # Assemble CUDSS version. Test for an empty match (not truthiness) so a literal 0 component is still accepted.
+  if("${CUDSS_VERSION_MINOR}" STREQUAL "")
+    set(CUDSS_VERSION "?")
+  else()
+    set(CUDSS_VERSION
+        "${CUDSS_VERSION_MAJOR}.${CUDSS_VERSION_MINOR}.${CUDSS_VERSION_PATCH}")
+  endif()
+endif()
+
+mark_as_advanced(CUDSS_ROOT CUDSS_INCLUDE_DIR CUDSS_LIBRARY CUDSS_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..6c15bde147469ddc84980dca0c756e8f26e1ddb1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake
@@ -0,0 +1,67 @@
+# Find the CUSPARSELT library
+#
+# The following variables are optionally searched for defaults
+#  CUSPARSELT_ROOT: Base directory where CUSPARSELT is found (env CUSPARSELT_ROOT_DIR is the deprecated spelling)
+#  CUSPARSELT_INCLUDE_DIR: Directory where CUSPARSELT header is searched for
+#  CUSPARSELT_LIBRARY: Directory where CUSPARSELT library is searched for
+#
+# The following are set after configuration is done:
+#  CUSPARSELT_FOUND
+#  CUSPARSELT_INCLUDE_PATH
+#  CUSPARSELT_LIBRARY_PATH (plus CUSPARSELT_VERSION when the library is found)
+
+include(FindPackageHandleStandardArgs)
+
+set(CUSPARSELT_ROOT $ENV{CUSPARSELT_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuSPARSELt")
+if(DEFINED ENV{CUSPARSELT_ROOT_DIR})  # ENV{...} tests the env var itself; $ENV{...} would expand to its value first
+  message(WARNING "CUSPARSELT_ROOT_DIR is deprecated. Please set CUSPARSELT_ROOT instead.")
+endif()
+list(APPEND CUSPARSELT_ROOT $ENV{CUSPARSELT_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+
+# Compatibility layer for CMake <3.12; CMake >=3.12 already takes CUSPARSELT_ROOT into account when searching for paths and libraries.
+list(APPEND CMAKE_PREFIX_PATH ${CUSPARSELT_ROOT})
+
+set(CUSPARSELT_INCLUDE_DIR $ENV{CUSPARSELT_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuSPARSELt header files")
+
+find_path(CUSPARSELT_INCLUDE_PATH cusparseLt.h
+  HINTS ${CUSPARSELT_INCLUDE_DIR}
+  PATH_SUFFIXES cuda/include cuda include)
+
+set(CUSPARSELT_LIBRARY $ENV{CUSPARSELT_LIBRARY} CACHE PATH "Path to the cusparselt library file (e.g., libcusparseLt.so)")
+
+set(CUSPARSELT_LIBRARY_NAME "libcusparseLt.so")
+if(MSVC)
+  set(CUSPARSELT_LIBRARY_NAME "cusparseLt.lib")
+endif()
+
+find_library(CUSPARSELT_LIBRARY_PATH ${CUSPARSELT_LIBRARY_NAME}
+  PATHS ${CUSPARSELT_LIBRARY}
+  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+
+find_package_handle_standard_args(CUSPARSELT DEFAULT_MSG CUSPARSELT_LIBRARY_PATH CUSPARSELT_INCLUDE_PATH)
+
+if(CUSPARSELT_FOUND)
+  # Get cuSPARSELt version: each component is first matched in the header, then reduced to its digits.
+  file(READ ${CUSPARSELT_INCLUDE_PATH}/cusparseLt.h CUSPARSELT_HEADER_CONTENTS)
+  string(REGEX MATCH "define CUSPARSELT_VER_MAJOR * +([0-9]+)"
+               CUSPARSELT_VERSION_MAJOR "${CUSPARSELT_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUSPARSELT_VER_MAJOR * +([0-9]+)" "\\1"
+               CUSPARSELT_VERSION_MAJOR "${CUSPARSELT_VERSION_MAJOR}")
+  string(REGEX MATCH "define CUSPARSELT_VER_MINOR * +([0-9]+)"
+               CUSPARSELT_VERSION_MINOR "${CUSPARSELT_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUSPARSELT_VER_MINOR * +([0-9]+)" "\\1"
+               CUSPARSELT_VERSION_MINOR "${CUSPARSELT_VERSION_MINOR}")
+  string(REGEX MATCH "define CUSPARSELT_VER_PATCH * +([0-9]+)"
+               CUSPARSELT_VERSION_PATCH "${CUSPARSELT_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUSPARSELT_VER_PATCH * +([0-9]+)" "\\1"
+               CUSPARSELT_VERSION_PATCH "${CUSPARSELT_VERSION_PATCH}")
+  # Assemble cuSPARSELt version. Test for an empty match (not truthiness) so a literal 0 component is still accepted.
+  if("${CUSPARSELT_VERSION_MINOR}" STREQUAL "")
+    set(CUSPARSELT_VERSION "?")
+  else()
+    set(CUSPARSELT_VERSION
+        "${CUSPARSELT_VERSION_MAJOR}.${CUSPARSELT_VERSION_MINOR}.${CUSPARSELT_VERSION_PATCH}")
+  endif()
+endif()
+
+mark_as_advanced(CUSPARSELT_ROOT CUSPARSELT_INCLUDE_DIR CUSPARSELT_LIBRARY CUSPARSELT_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..337afa1bfe4178d1af041c6504c1124b8c31d482
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake
@@ -0,0 +1,141 @@
+# This will define the following variables:
+# SYCL_FOUND               : True if the system has the SYCL library.
+# SYCL_INCLUDE_DIR         : Include directories needed to use SYCL.
+# SYCL_LIBRARY_DIR         : The path to the SYCL library.
+# SYCL_LIBRARY             : SYCL library fullname.
+# SYCL_COMPILER_VERSION    : SYCL compiler version.
+
+include(FindPackageHandleStandardArgs)
+
+set(SYCL_ROOT "")  # stays empty unless one of the sources below provides a root
+if(DEFINED ENV{SYCL_ROOT})
+  set(SYCL_ROOT $ENV{SYCL_ROOT})
+elseif(DEFINED ENV{CMPLR_ROOT})
+  set(SYCL_ROOT $ENV{CMPLR_ROOT})
+else()
+  # Use the default path to ensure proper linking with torch::xpurt when the user is working with libtorch.
+  if(CMAKE_SYSTEM_NAME MATCHES "Linux")
+    set(SYCL_ROOT "/opt/intel/oneapi/compiler/latest")
+  elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
+    set(SYCL_ROOT "C:/Program Files (x86)/Intel/oneAPI/compiler/latest")
+  endif()
+  if(NOT EXISTS "${SYCL_ROOT}")  # quoted so an empty root (no default for this platform) is still a real operand
+    set(SYCL_ROOT "")
+  endif()
+endif()
+
+string(COMPARE EQUAL "${SYCL_ROOT}" "" nosyclfound)
+if(nosyclfound)  # no root from the environment and no existing default: report not-found and stop
+  set(SYCL_FOUND False)
+  set(SYCL_REASON_FAILURE "SYCL library not set!!")
+  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
+  return()
+endif()
+
+# Find SYCL compiler executable (`icx`), looking only in bin/bin64 under SYCL_ROOT.
+find_program(
+  SYCL_COMPILER
+  NAMES icx
+  PATHS "${SYCL_ROOT}"
+  PATH_SUFFIXES bin bin64
+  NO_DEFAULT_PATH
+  )
+
+function(parse_sycl_compiler_version version_number)
+  # Run `${SYCL_COMPILER} --version` and extract MAJOR.MINOR.PATCH from its banner into ${version_number}.
+  execute_process(COMMAND ${SYCL_COMPILER} --version OUTPUT_VARIABLE SYCL_VERSION_STRING)
+  # No recognisable `Intel(R) ... Compiler X.Y.Z` banner (empty or unexpected output): leave the
+  # result unset so the caller's "Cannot parse sycl compiler version" branch handles it.
+  if(NOT "${SYCL_VERSION_STRING}" MATCHES "Intel\\(R\\) .* Compiler ([0-9]+)\\.([0-9]+)\\.([0-9]+) ")
+    return()
+  endif()
+  # CMAKE_MATCH_1..3 now hold the major, minor, and patch components.
+  # Calculate the version number in the format XXXXYYZZ, using the formula (major * 10000 + minor * 100 + patch).
+  math(EXPR VERSION_NUMBER_MATCH
+       "${CMAKE_MATCH_1} * 10000 + ${CMAKE_MATCH_2} * 100 + ${CMAKE_MATCH_3}")
+  set(${version_number} "${VERSION_NUMBER_MATCH}" PARENT_SCOPE)
+endfunction()
+
+if(SYCL_COMPILER)
+  parse_sycl_compiler_version(SYCL_COMPILER_VERSION)
+endif()
+
+if(NOT SYCL_COMPILER_VERSION)
+  set(SYCL_FOUND False)
+  set(SYCL_REASON_FAILURE "Cannot parse sycl compiler version to get SYCL_COMPILER_VERSION!")
+  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
+  return()
+endif()
+
+# Locate the `include` directory directly under SYCL_ROOT (find_file is used here to find a directory by name).
+find_file(
+  SYCL_INCLUDE_DIR
+  NAMES include
+  HINTS ${SYCL_ROOT}
+  NO_DEFAULT_PATH
+  )
+
+# Locate the `sycl` subdirectory of ${SYCL_ROOT}/include.
+find_file(
+  SYCL_INCLUDE_SYCL_DIR
+  NAMES sycl
+  HINTS ${SYCL_ROOT}/include/
+  NO_DEFAULT_PATH
+  )
+
+# Also add include/sycl explicitly: other compilers do not recognize the `-fsycl` option (presumably it would add this path -- TODO confirm).
+list(APPEND SYCL_INCLUDE_DIR ${SYCL_INCLUDE_SYCL_DIR})
+
+# Locate the `lib` (or `lib64`) directory directly under SYCL_ROOT.
+find_file(
+  SYCL_LIBRARY_DIR
+  NAMES lib lib64
+  HINTS ${SYCL_ROOT}
+  NO_DEFAULT_PATH
+  )
+
+# Threshold for the old SYCL toolkit that is compatible with the current version of PyTorch; versions <= this take the legacy branch below.
+set(PYTORCH_2_5_SYCL_TOOLKIT_VERSION 20249999)
+
+# By default, we use libsycl.so on Linux and sycl.lib on Windows as the SYCL library name.
+if (SYCL_COMPILER_VERSION VERSION_LESS_EQUAL PYTORCH_2_5_SYCL_TOOLKIT_VERSION)
+  # Match on CMAKE_SYSTEM_NAME rather than platform shorthand variables: per the link
+  # below, `LINUX` needs cmake>=3.25, and this file is installed and used by other projects.
+  # See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html
+  if(CMAKE_SYSTEM_NAME MATCHES "Windows")
+    # On Windows, the SYCL library is named sycl7.lib up to and including PYTORCH_2_5_SYCL_TOOLKIT_VERSION.
+    # sycl.lib is supported in later versions.
+    set(sycl_lib_suffix "7")
+  endif()
+endif()
+
+# Find SYCL library fullname (sycl_lib_suffix is empty unless set above).
+find_library(
+  SYCL_LIBRARY
+  NAMES "sycl${sycl_lib_suffix}"
+  HINTS ${SYCL_LIBRARY_DIR}
+  NO_DEFAULT_PATH
+)
+
+# Find OpenCL library fullname, which is a dependency of oneDNN.
+find_library(
+  OCL_LIBRARY
+  NAMES OpenCL
+  HINTS ${SYCL_LIBRARY_DIR}
+  NO_DEFAULT_PATH
+)
+
+if((NOT SYCL_LIBRARY) OR (NOT OCL_LIBRARY))
+  set(SYCL_FOUND False)
+  set(SYCL_REASON_FAILURE "SYCL library is incomplete!!")
+  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
+  return()
+endif()
+
+find_package_handle_standard_args(
+  SYCL
+  FOUND_VAR SYCL_FOUND
+  REQUIRED_VARS SYCL_INCLUDE_DIR SYCL_LIBRARY_DIR SYCL_LIBRARY
+  REASON_FAILURE_MESSAGE "${SYCL_REASON_FAILURE}"
+  VERSION_VAR SYCL_COMPILER_VERSION
+  )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..55c4e83012d820995f59b717ecb676452f9ccbec
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake
@@ -0,0 +1,10 @@
+# This is a wrapper of the upstream `./upstream/FindCUDA.cmake`.
+# NOTE(review): the original description says the wrapper first includes
+# `./upstream/CMakeInitializeConfigs.cmake` (absent in old CMake versions; it
+# creates variables the latter needs), yet only FindCUDA.cmake is included
+# below -- confirm whether that include is still needed.
+# See ./README.md for details.
+
+set(UPSTREAM_FIND_CUDA_DIR "${CMAKE_CURRENT_LIST_DIR}/upstream/")
+
+include("${UPSTREAM_FIND_CUDA_DIR}/FindCUDA.cmake")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..82134328c803dc87a89564638540a6cbcfa2d906
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake
@@ -0,0 +1,78 @@
+# Find the CUDNN libraries
+#
+# The following variables are optionally searched for defaults
+#  CUDNN_ROOT: Base directory where CUDNN is found (env CUDNN_ROOT_DIR is the deprecated spelling)
+#  CUDNN_INCLUDE_DIR: Directory where CUDNN header is searched for
+#  CUDNN_LIBRARY: Directory where CUDNN library is searched for
+#  CUDNN_STATIC: Are we looking for a static library? (default: no)
+#
+# The following are set after configuration is done:
+#  CUDNN_FOUND
+#  CUDNN_INCLUDE_PATH
+#  CUDNN_LIBRARY_PATH (plus CUDNN_VERSION when the library is found)
+#
+
+include(FindPackageHandleStandardArgs)
+
+set(CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuDNN")
+if(DEFINED ENV{CUDNN_ROOT_DIR})  # ENV{...} tests the env var itself; $ENV{...} would expand to its value first
+  message(WARNING "CUDNN_ROOT_DIR is deprecated. Please set CUDNN_ROOT instead.")
+endif()
+list(APPEND CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+
+# Compatibility layer for CMake <3.12; CMake >=3.12 already takes CUDNN_ROOT into account when searching for paths and libraries.
+list(APPEND CMAKE_PREFIX_PATH ${CUDNN_ROOT})
+
+set(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuDNN header files")
+
+find_path(CUDNN_INCLUDE_PATH cudnn.h
+  HINTS ${CUDNN_INCLUDE_DIR}
+  PATH_SUFFIXES cuda/include cuda include)
+
+option(CUDNN_STATIC "Look for static CUDNN" OFF)
+if (CUDNN_STATIC)
+  set(CUDNN_LIBNAME "libcudnn_static.a")
+else()
+  set(CUDNN_LIBNAME "cudnn")
+endif()
+
+set(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY} CACHE PATH "Path to the cudnn library file (e.g., libcudnn.so)")
+if (CUDNN_LIBRARY MATCHES ".*cudnn_static.a" AND NOT CUDNN_STATIC)
+  message(WARNING "CUDNN_LIBRARY points to a static library (${CUDNN_LIBRARY}) but CUDNN_STATIC is OFF.")
+endif()
+
+find_library(CUDNN_LIBRARY_PATH ${CUDNN_LIBNAME}
+  PATHS ${CUDNN_LIBRARY}
+  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+
+find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_LIBRARY_PATH CUDNN_INCLUDE_PATH)
+
+if(CUDNN_FOUND)
+  # Get cuDNN version, preferring cudnn_version.h when it exists and falling back to cudnn.h.
+  if(EXISTS ${CUDNN_INCLUDE_PATH}/cudnn_version.h)
+    file(READ ${CUDNN_INCLUDE_PATH}/cudnn_version.h CUDNN_HEADER_CONTENTS)
+  else()
+    file(READ ${CUDNN_INCLUDE_PATH}/cudnn.h CUDNN_HEADER_CONTENTS)
+  endif()
+  string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
+               CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
+               CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}")
+  string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)"
+               CUDNN_VERSION_MINOR "${CUDNN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1"
+               CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}")
+  string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)"
+               CUDNN_VERSION_PATCH "${CUDNN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1"
+               CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}")
+  # Assemble cuDNN version ("?" when no major version could be read from the header)
+  if(NOT CUDNN_VERSION_MAJOR)
+    set(CUDNN_VERSION "?")
+  else()
+    set(CUDNN_VERSION
+        "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}")
+  endif()
+endif()
+
+mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..5517e8f0624b1e5538b761e1f4891227007d0045
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake
@@ -0,0 +1,40 @@
+# Distributed under the OSI-approved BSD 3-Clause License.  See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+# Present in upstream, but not supported on versions of cmake we need to support
+# include_guard(GLOBAL)
+
+# Initializes `<_PREFIX>` and the per-config `<_PREFIX>_<CONFIG>` cache variables from
+# the corresponding `<_PREFIX>_INIT` / `<_PREFIX>_<CONFIG>_INIT`, for the configurations currently used.
+function(cmake_initialize_per_config_variable _PREFIX _DOCSTRING)
+  string(STRIP "${${_PREFIX}_INIT}" _INIT)
+  set("${_PREFIX}" "${_INIT}"
+    CACHE STRING "${_DOCSTRING} during all build types.")
+  mark_as_advanced("${_PREFIX}")
+
+  if (NOT CMAKE_NOT_USING_CONFIG_FLAGS)
+    set(_CONFIGS Debug Release MinSizeRel RelWithDebInfo)  # always covered, whatever the generator
+
+    get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
+    if (_GENERATOR_IS_MULTI_CONFIG)
+      list(APPEND _CONFIGS ${CMAKE_CONFIGURATION_TYPES})  # multi-config: add the configured types
+    else()
+      if (NOT CMAKE_NO_BUILD_TYPE)
+        set(CMAKE_BUILD_TYPE "${CMAKE_BUILD_TYPE_INIT}" CACHE STRING
+          "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel ...")
+      endif()
+      list(APPEND _CONFIGS ${CMAKE_BUILD_TYPE})  # single-config: add the selected build type
+    endif()
+
+    list(REMOVE_DUPLICATES _CONFIGS)
+    foreach(_BUILD_TYPE IN LISTS _CONFIGS)
+      if (NOT "${_BUILD_TYPE}" STREQUAL "")
+        string(TOUPPER "${_BUILD_TYPE}" _BUILD_TYPE)  # variable names use the upper-cased config
+        string(STRIP "${${_PREFIX}_${_BUILD_TYPE}_INIT}" _INIT)
+        set("${_PREFIX}_${_BUILD_TYPE}" "${_INIT}"
+          CACHE STRING "${_DOCSTRING} during ${_BUILD_TYPE} builds.")
+        mark_as_advanced("${_PREFIX}_${_BUILD_TYPE}")
+      endif()
+    endforeach()
+  endif()
+endfunction()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..411a246656b3bdaba6abc238fd35caf959c9cca0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake
@@ -0,0 +1,1981 @@
+#.rst:
+# FindCUDA
+# --------
+#
+# .. note::
+#
+#   The FindCUDA module has been superseded by first-class support
+#   for the CUDA language in CMake.  It is no longer necessary to
+#   use this module or call ``find_package(CUDA)``.  This module
+#   now exists only for compatibility with projects that have not
+#   been ported.
+#
+#   Instead, list ``CUDA`` among the languages named in the top-level
+#   call to the :command:`project` command, or call the
+#   :command:`enable_language` command with ``CUDA``.
+#   Then one can add CUDA (``.cu``) sources to programs directly
+#   in calls to :command:`add_library` and :command:`add_executable`.
+#
+# Tools for building CUDA C files: libraries and build dependencies.
+#
+# This script locates the NVIDIA CUDA C tools.  It should work on Linux,
+# Windows, and macOS and should be reasonably up to date with CUDA C
+# releases.
+#
+# This script makes use of the standard :command:`find_package` arguments of
+# ``<VERSION>``, ``REQUIRED`` and ``QUIET``.  ``CUDA_FOUND`` will report if an
+# acceptable version of CUDA was found.
+#
+# The script will prompt the user to specify ``CUDA_TOOLKIT_ROOT_DIR`` if
+# the prefix cannot be determined by the location of nvcc in the system
+# path and ``REQUIRED`` is specified to :command:`find_package`.  To use
+# a different installed version of the toolkit set the environment variable
+# ``CUDA_BIN_PATH`` before running cmake (e.g.
+# ``CUDA_BIN_PATH=/usr/local/cuda1.0`` instead of the default
+# ``/usr/local/cuda``) or set ``CUDA_TOOLKIT_ROOT_DIR`` after configuring.  If
+# you change the value of ``CUDA_TOOLKIT_ROOT_DIR``, various components that
+# depend on the path will be relocated.
+#
+# It might be necessary to set ``CUDA_TOOLKIT_ROOT_DIR`` manually on certain
+# platforms, or to use a CUDA runtime not installed in the default
+# location.  In newer versions of the toolkit the CUDA library is
+# included with the graphics driver -- be sure that the driver version
+# matches what is needed by the CUDA runtime version.
+#
+# The following variables affect the behavior of the macros in the
+# script (in alphabetical order).  Note that any of these flags can be
+# changed multiple times in the same directory before calling
+# ``CUDA_ADD_EXECUTABLE``, ``CUDA_ADD_LIBRARY``, ``CUDA_COMPILE``,
+# ``CUDA_COMPILE_PTX``, ``CUDA_COMPILE_FATBIN``, ``CUDA_COMPILE_CUBIN``
+# or ``CUDA_WRAP_SRCS``::
+#
+#   CUDA_64_BIT_DEVICE_CODE (Default matches host bit size)
+#   -- Set to ON to compile for 64 bit device code, OFF for 32 bit device code.
+#      Note that making this different from the host code when generating object
+#      or C files from CUDA code just won't work, because size_t gets defined by
+#      nvcc in the generated source.  If you compile to PTX and then load the
+#      file yourself, you can mix bit sizes between device and host.
+#
+#   CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE (Default ON)
+#   -- Set to ON if you want the custom build rule to be attached to the source
+#      file in Visual Studio.  Turn OFF if you add the same cuda file to multiple
+#      targets.
+#
+#      This allows the user to build the target from the CUDA file; however, bad
+#      things can happen if the CUDA source file is added to multiple targets.
+#      When performing parallel builds it is possible for the custom build
+#      command to be run more than once and in parallel causing cryptic build
+#      errors.  VS runs the rules for every source file in the target, and a
+#      source can have only one rule no matter how many projects it is added to.
+#      When the rule is run from multiple targets race conditions can occur on
+#      the generated file.  Eventually everything will get built, but if the user
+#      is unaware of this behavior, there may be confusion.  It would be nice if
+#      this script could detect the reuse of source files across multiple targets
+#      and turn the option off for the user, but no good solution could be found.
+#
+#   CUDA_BUILD_CUBIN (Default OFF)
+#   -- Set to ON to enable an extra compilation pass with the -cubin option in
+#      Device mode. The output is parsed and register, shared memory usage is
+#      printed during build.
+#
+#   CUDA_BUILD_EMULATION (Default OFF for device mode)
+#   -- Set to ON for Emulation mode. -D_DEVICEEMU is defined for CUDA C files
+#      when CUDA_BUILD_EMULATION is TRUE.
+#
+#   CUDA_LINK_LIBRARIES_KEYWORD (Default "")
+#    -- The <PRIVATE|PUBLIC|INTERFACE> keyword to use for internal
+#       target_link_libraries calls. The default is to use no keyword which
+#       uses the old "plain" form of target_link_libraries. Note that it matters
+#       because whatever is used inside the FindCUDA module must also be used
+#       outside - the two forms of target_link_libraries cannot be mixed.
+#
+#   CUDA_GENERATED_OUTPUT_DIR (Default CMAKE_CURRENT_BINARY_DIR)
+#   -- Set to the path you wish to have the generated files placed.  If it is
+#      blank output files will be placed in CMAKE_CURRENT_BINARY_DIR.
+#      Intermediate files will always be placed in
+#      CMAKE_CURRENT_BINARY_DIR/CMakeFiles.
+#
+#   CUDA_HOST_COMPILATION_CPP (Default ON)
+#   -- Set to OFF for C compilation of host code.
+#
+#   CUDA_HOST_COMPILER (Default CMAKE_C_COMPILER)
+#   -- Set the host compiler to be used by nvcc.  Ignored if -ccbin or
+#      --compiler-bindir is already present in the CUDA_NVCC_FLAGS or
+#      CUDA_NVCC_FLAGS_<CONFIG> variables.  For Visual Studio targets,
+#      the host compiler is constructed with one or more visual studio macros
+#      such as $(VCInstallDir), that expands out to the path when
+#      the command is run from within VS.
+#      If the CUDAHOSTCXX environment variable is set it will
+#      be used as the default.
+#
+#   CUDA_NVCC_FLAGS
+#   CUDA_NVCC_FLAGS_<CONFIG>
+#   -- Additional NVCC command line arguments.  NOTE: multiple arguments must be
+#      semi-colon delimited (e.g. --compiler-options;-Wall)
+#
+#   CUDA_PROPAGATE_HOST_FLAGS (Default ON)
+#   -- Set to ON to propagate CMAKE_{C,CXX}_FLAGS and their configuration
+#      dependent counterparts (e.g. CMAKE_C_FLAGS_DEBUG) automatically to the
+#      host compiler through nvcc's -Xcompiler flag.  This helps make the
+#      generated host code match the rest of the system better.  Sometimes
+#      certain flags give nvcc problems, and this will help you turn the flag
+#      propagation off.  This does not affect the flags supplied directly to nvcc
+#      via CUDA_NVCC_FLAGS or through the OPTION flags specified through
+#      CUDA_ADD_LIBRARY, CUDA_ADD_EXECUTABLE, or CUDA_WRAP_SRCS.  Flags used for
+#      shared library compilation are not affected by this flag.
+#
+#   CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST (Default "")
+#   -- A list containing the host flags that should not be propagated when
+#      CUDA_PROPAGATE_HOST_FLAGS is ON.
+#
+#   CUDA_SEPARABLE_COMPILATION (Default OFF)
+#   -- If set this will enable separable compilation for all CUDA runtime object
+#      files.  If used outside of CUDA_ADD_EXECUTABLE and CUDA_ADD_LIBRARY
+#      (e.g. calling CUDA_WRAP_SRCS directly),
+#      CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME and
+#      CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS should be called.
+#
+#   CUDA_SOURCE_PROPERTY_FORMAT
+#   -- If this source file property is set, it can override the format specified
+#      to CUDA_WRAP_SRCS (OBJ, PTX, CUBIN, or FATBIN).  If an input source file
+#      is not a .cu file, setting this file will cause it to be treated as a .cu
+#      file. See documentation for set_source_files_properties on how to set
+#      this property.
+#
+#   CUDA_USE_STATIC_CUDA_RUNTIME (Default ON)
+#   -- When enabled the static version of the CUDA runtime library will be used
+#      in CUDA_LIBRARIES.  If the version of CUDA configured doesn't support
+#      this option, then it will be silently disabled.
+#
+#   CUDA_VERBOSE_BUILD (Default OFF)
+#   -- Set to ON to see all the commands used when building the CUDA file.  When
+#      using a Makefile generator the value defaults to VERBOSE (run make
+#      VERBOSE=1 to see output), although setting CUDA_VERBOSE_BUILD to ON will
+#      always print the output.
+#
+# The script creates the following macros (in alphabetical order)::
+#
+#   CUDA_ADD_CUFFT_TO_TARGET( cuda_target )
+#   -- Adds the cufft library to the target (can be any target).  Handles whether
+#      you are in emulation mode or not.
+#
+#   CUDA_ADD_CUBLAS_TO_TARGET( cuda_target )
+#   -- Adds the cublas library to the target (can be any target).  Handles
+#      whether you are in emulation mode or not.
+#
+#   CUDA_ADD_EXECUTABLE( cuda_target file0 file1 ...
+#                        [WIN32] [MACOSX_BUNDLE] [EXCLUDE_FROM_ALL] [OPTIONS ...] )
+#   -- Creates an executable "cuda_target" which is made up of the files
+#      specified.  All of the non CUDA C files are compiled using the standard
+#      build rules specified by CMAKE and the cuda files are compiled to object
+#      files using nvcc and the host compiler.  In addition CUDA_INCLUDE_DIRS is
+#      added automatically to include_directories().  Some standard CMake target
+#      calls can be used on the target after calling this macro
+#      (e.g. set_target_properties and target_link_libraries), but setting
+#      properties that adjust compilation flags will not affect code compiled by
+#      nvcc.  Such flags should be modified before calling CUDA_ADD_EXECUTABLE,
+#      CUDA_ADD_LIBRARY or CUDA_WRAP_SRCS.
+#
+#   CUDA_ADD_LIBRARY( cuda_target file0 file1 ...
+#                     [STATIC | SHARED | MODULE] [EXCLUDE_FROM_ALL] [OPTIONS ...] )
+#   -- Same as CUDA_ADD_EXECUTABLE except that a library is created.
+#
+#   CUDA_BUILD_CLEAN_TARGET()
+#   -- Creates a convenience target that deletes all the dependency files
+#      generated.  You should make clean after running this target to ensure the
+#      dependency files get regenerated.
+#
+#   CUDA_COMPILE( generated_files file0 file1 ... [STATIC | SHARED | MODULE]
+#                 [OPTIONS ...] )
+#   -- Returns a list of generated files from the input source files to be used
+#      with ADD_LIBRARY or ADD_EXECUTABLE.
+#
+#   CUDA_COMPILE_PTX( generated_files file0 file1 ... [OPTIONS ...] )
+#   -- Returns a list of PTX files generated from the input source files.
+#
+#   CUDA_COMPILE_FATBIN( generated_files file0 file1 ... [OPTIONS ...] )
+#   -- Returns a list of FATBIN files generated from the input source files.
+#
+#   CUDA_COMPILE_CUBIN( generated_files file0 file1 ... [OPTIONS ...] )
+#   -- Returns a list of CUBIN files generated from the input source files.
+#
+#   CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME( output_file_var
+#                                                        cuda_target
+#                                                        object_files )
+#   -- Compute the name of the intermediate link file used for separable
+#      compilation.  This file name is typically passed into
+#      CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS.  output_file_var is produced
+#      based on cuda_target and the list of object files that need separable
+#      compilation as specified by object_files.  If the object_files list is
+#      empty, then output_file_var will be empty.  This function is called
+#      automatically for CUDA_ADD_LIBRARY and CUDA_ADD_EXECUTABLE.  Note that
+#      this is a function and not a macro.
+#
+#   CUDA_INCLUDE_DIRECTORIES( path0 path1 ... )
+#   -- Sets the directories that should be passed to nvcc
+#      (e.g. nvcc -Ipath0 -Ipath1 ... ). These paths usually contain other .cu
+#      files.
+#
+#
+#   CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS( output_file_var cuda_target
+#                                            nvcc_flags object_files)
+#   -- Generates the link object required by separable compilation from the given
+#      object files.  This is called automatically for CUDA_ADD_EXECUTABLE and
+#      CUDA_ADD_LIBRARY, but can be called manually when using CUDA_WRAP_SRCS
+#      directly.  When called from CUDA_ADD_LIBRARY or CUDA_ADD_EXECUTABLE the
+#      nvcc_flags passed in are the same as the flags passed in via the OPTIONS
+#      argument.  The only nvcc flag added automatically is the bitness flag as
+#      specified by CUDA_64_BIT_DEVICE_CODE.  Note that this is a function
+#      instead of a macro.
+#
+#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
+#   -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
+#      target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
+#       - "Auto" detects local machine GPU compute arch at runtime.
+#       - "Common" and "All" cover common and entire subsets of architectures
+#      ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
+#      NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing
+#      NUM: Any number. Only those pairs are currently accepted by NVCC though:
+#            3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5
+#      Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
+#      Additionally, sets ${out_variable}_readable to the resulting numeric list
+#      Example:
+#       CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
+#        LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
+#
+#      More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
+#      Note that this is a function instead of a macro.
+#
+#   CUDA_WRAP_SRCS ( cuda_target format generated_files file0 file1 ...
+#                    [STATIC | SHARED | MODULE] [OPTIONS ...] )
+#   -- This is where all the magic happens.  CUDA_ADD_EXECUTABLE,
+#      CUDA_ADD_LIBRARY, CUDA_COMPILE, and CUDA_COMPILE_PTX all call this
+#      function under the hood.
+#
+#      Given the list of files (file0 file1 ... fileN) this macro generates
+#      custom commands that generate either PTX or linkable objects (use "PTX" or
+#      "OBJ" for the format argument to switch).  Files that don't end with .cu
+#      or have the HEADER_FILE_ONLY property are ignored.
+#
+#      The arguments passed in after OPTIONS are extra command line options to
+#      give to nvcc.  You can also specify per configuration options by
+#      specifying the name of the configuration followed by the options.  General
+#      options must precede configuration specific options.  Not all
+#      configurations need to be specified, only the ones provided will be used.
+#
+#         OPTIONS -DFLAG=2 "-DFLAG_OTHER=space in flag"
+#         DEBUG -g
+#         RELEASE --use_fast_math
+#         RELWITHDEBINFO --use_fast_math;-g
+#         MINSIZEREL --use_fast_math
+#
+#      For certain configurations (namely VS generating object files with
+#      CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE set to ON), no generated file will
+#      be produced for the given cuda file.  This is because when you add the
+#      cuda file to Visual Studio it knows that this file produces an object file
+#      and will link in the resulting object file automatically.
+#
+#      This script will also generate a separate cmake script that is used at
+#      build time to invoke nvcc.  This is for several reasons.
+#
+#        1. nvcc can return negative numbers as return values which confuses
+#        Visual Studio into thinking that the command succeeded.  The script now
+#        checks the error codes and produces errors when there was a problem.
+#
+#        2. nvcc has been known to not delete incomplete results when it
+#        encounters problems.  This confuses build systems into thinking the
+#        target was generated when in fact an unusable file exists.  The script
+#        now deletes the output files if there was an error.
+#
+#        3. By putting all the options that affect the build into a file and then
+#        make the build rule dependent on the file, the output files will be
+#        regenerated when the options change.
+#
+#      This script also looks at optional arguments STATIC, SHARED, or MODULE to
+#      determine when to target the object compilation for a shared library.
+#      BUILD_SHARED_LIBS is ignored in CUDA_WRAP_SRCS, but it is respected in
+#      CUDA_ADD_LIBRARY.  On some systems special flags are added for building
+#      objects intended for shared libraries.  A preprocessor macro,
+#      _EXPORTS is defined when a shared library compilation is
+#      detected.
+#
+#      Flags passed into add_definitions with -D or /D are passed along to nvcc.
+#
+#
+#
+# The script defines the following variables::
+#
+#   CUDA_VERSION_MAJOR    -- The major version of cuda as reported by nvcc.
+#   CUDA_VERSION_MINOR    -- The minor version.
+#   CUDA_VERSION
+#   CUDA_VERSION_STRING   -- CUDA_VERSION_MAJOR.CUDA_VERSION_MINOR
+#   CUDA_HAS_FP16         -- Whether a short float (float16,fp16) is supported.
+#
+#   CUDA_TOOLKIT_ROOT_DIR -- Path to the CUDA Toolkit (defined if not set).
+#   CUDA_SDK_ROOT_DIR     -- Path to the CUDA SDK.  Use this to find files in the
+#                            SDK.  This script will not directly support finding
+#                            specific libraries or headers, as that isn't
+#                            supported by NVIDIA.  If you want to change
+#                            libraries when the path changes see the
+#                            FindCUDA.cmake script for an example of how to clear
+#                            these variables.  There are also examples of how to
+#                            use the CUDA_SDK_ROOT_DIR to locate headers or
+#                            libraries, if you so choose (at your own risk).
+#   CUDA_INCLUDE_DIRS     -- Include directory for cuda headers.  Added automatically
+#                            for CUDA_ADD_EXECUTABLE and CUDA_ADD_LIBRARY.
+#   CUDA_LIBRARIES        -- Cuda RT library.
+#   CUDA_CUFFT_LIBRARIES  -- Device or emulation library for the Cuda FFT
+#                            implementation (alternative to:
+#                            CUDA_ADD_CUFFT_TO_TARGET macro)
+#   CUDA_CUBLAS_LIBRARIES -- Device or emulation library for the Cuda BLAS
+#                            implementation (alternative to:
+#                            CUDA_ADD_CUBLAS_TO_TARGET macro).
+#   CUDA_cudart_static_LIBRARY -- Statically linkable cuda runtime library.
+#                                 Only available for CUDA version 5.5+
+#   CUDA_cudadevrt_LIBRARY -- Device runtime library.
+#                             Required for separable compilation.
+#   CUDA_cupti_LIBRARY    -- CUDA Profiling Tools Interface library.
+#                            Only available for CUDA version 4.0+.
+#   CUDA_curand_LIBRARY   -- CUDA Random Number Generation library.
+#                            Only available for CUDA version 3.2+.
+#   CUDA_cusolver_LIBRARY -- CUDA Direct Solver library.
+#                            Only available for CUDA version 7.0+.
+#   CUDA_cusparse_LIBRARY -- CUDA Sparse Matrix library.
+#                            Only available for CUDA version 3.2+.
+#   CUDA_npp_LIBRARY      -- NVIDIA Performance Primitives lib.
+#                            Only available for CUDA version 4.0+.
+#   CUDA_nppc_LIBRARY     -- NVIDIA Performance Primitives lib (core).
+#                            Only available for CUDA version 5.5+.
+#   CUDA_nppi_LIBRARY     -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 5.5 - 8.0.
+#   CUDA_nppial_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppicc_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppicom_LIBRARY  -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppidei_LIBRARY  -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppif_LIBRARY    -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppig_LIBRARY    -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppim_LIBRARY    -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppist_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppisu_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppitc_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_npps_LIBRARY     -- NVIDIA Performance Primitives lib (signal processing).
+#                            Only available for CUDA version 5.5+.
+#   CUDA_nvcuvenc_LIBRARY -- CUDA Video Encoder library.
+#                            Only available for CUDA version 3.2+.
+#                            Windows only.
+#   CUDA_nvcuvid_LIBRARY  -- CUDA Video Decoder library.
+#                            Only available for CUDA version 3.2+.
+#                            Windows only.
+#
+
+#   James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#   Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#   Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#   Copyright (c) 2007-2009
+#   Scientific Computing and Imaging Institute, University of Utah
+#
+#   This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#   for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+###############################################################################
+
+# FindCUDA.cmake
+
+include(FindPackageHandleStandardArgs)
+# Locate a helper file shipped alongside this module (in the FindCUDA/
+# subdirectory) and store its full path in the cache variable CUDA_<_name>.
+#   _name      -- base name of the helper file; also used to form the name of
+#                 the result variable, CUDA_<_name>.
+#   _extension -- extension of the helper file, without the leading dot.
+# When the file is missing, a FATAL_ERROR is issued if CUDA_FIND_REQUIRED is
+# set; otherwise a STATUS message is printed unless CUDA_FIND_QUIETLY is set,
+# and the cache entry is still written.
+# Note that this is a macro, so _full_name, error_message and
+# CMAKE_CURRENT_LIST_DIR are set in the caller's scope.
+macro(CUDA_FIND_HELPER_FILE _name _extension)
+  set(_full_name "${_name}.${_extension}")
+  # CMAKE_CURRENT_LIST_FILE contains the full path to the file currently being
+  # processed.  Using this variable, we can pull out the current path, and
+  # provide a way to get access to the other files we need local to here.
+  get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+  set(CUDA_${_name} "${CMAKE_CURRENT_LIST_DIR}/FindCUDA/${_full_name}")
+  if(NOT EXISTS "${CUDA_${_name}}")
+    set(error_message "${_full_name} not found in ${CMAKE_CURRENT_LIST_DIR}/FindCUDA")
+    if(CUDA_FIND_REQUIRED)
+      message(FATAL_ERROR "${error_message}")
+    else()
+      if(NOT CUDA_FIND_QUIETLY)
+        message(STATUS "${error_message}")
+      endif()
+    endif()
+  endif()
+  # Set this variable as internal, so the user isn't bugged with it.
+  set(CUDA_${_name} ${CUDA_${_name}} CACHE INTERNAL "Location of ${_full_name}" FORCE)
+endmacro()
+
+#####################################################################
+## CUDA_INCLUDE_NVCC_DEPENDENCIES
+##
+
+# So we want to try and include the dependency file if it exists.  If
+# it doesn't exist then we need to create an empty one, so we can
+# include it.
+
+# If it does exist, then we need to check to see if all the files it
+# depends on exist.  If they don't then we should clear the dependency
+# file and regenerate it later.  This covers the case where a header
+# file has disappeared or moved.
+
+# Include the nvcc dependency file `dependency_file`, creating it if needed.
+#   dependency_file -- path of the generated dependency file to include.
+# On return (in the caller's scope, since this is a macro):
+#   CUDA_NVCC_DEPEND            -- the dependency list left by the included
+#                                  file, or `dependency_file` itself when the
+#                                  dependencies have to be regenerated.
+#   CUDA_NVCC_DEPEND_REGENERATE -- TRUE when the list was empty or named a file
+#                                  that no longer exists; in that case the
+#                                  dependency file is rewritten as an empty stub.
+macro(CUDA_INCLUDE_NVCC_DEPENDENCIES dependency_file)
+  # Start from a clean state; presumably the included file sets
+  # CUDA_NVCC_DEPEND (its contents are generated elsewhere).
+  set(CUDA_NVCC_DEPEND)
+  set(CUDA_NVCC_DEPEND_REGENERATE FALSE)
+
+
+  # Include the dependency file.  Create it first if it doesn't exist.  The
+  # INCLUDE puts a dependency that will force CMake to rerun and bring in the
+  # new info when it changes.  DO NOT REMOVE THIS (as I did and spent a few
+  # hours figuring out why it didn't work).
+  if(NOT EXISTS ${dependency_file})
+    file(WRITE ${dependency_file} "#FindCUDA.cmake generated file.  Do not edit.\n")
+  endif()
+  # Always include this file to force CMake to run again next
+  # invocation and rebuild the dependencies.
+  #message("including dependency_file = ${dependency_file}")
+  include(${dependency_file})
+
+  # Now we need to verify the existence of all the included files
+  # here.  If they aren't there we need to just blank this variable and
+  # make the file regenerate again.
+#   if(DEFINED CUDA_NVCC_DEPEND)
+#     message("CUDA_NVCC_DEPEND set")
+#   else()
+#     message("CUDA_NVCC_DEPEND NOT set")
+#   endif()
+  if(CUDA_NVCC_DEPEND)
+    #message("CUDA_NVCC_DEPEND found")
+    foreach(f ${CUDA_NVCC_DEPEND})
+      # message("searching for ${f}")
+      if(NOT EXISTS ${f})
+        #message("file ${f} not found")
+        set(CUDA_NVCC_DEPEND_REGENERATE TRUE)
+      endif()
+    endforeach()
+  else()
+    #message("CUDA_NVCC_DEPEND false")
+    # No dependencies, so regenerate the file.
+    set(CUDA_NVCC_DEPEND_REGENERATE TRUE)
+  endif()
+
+  #message("CUDA_NVCC_DEPEND_REGENERATE = ${CUDA_NVCC_DEPEND_REGENERATE}")
+  # No incoming dependencies, so we need to generate them.  Make the
+  # output depend on the dependency file itself, which should cause the
+  # rule to re-run.
+  if(CUDA_NVCC_DEPEND_REGENERATE)
+    set(CUDA_NVCC_DEPEND ${dependency_file})
+    #message("Generating an empty dependency_file: ${dependency_file}")
+    file(WRITE ${dependency_file} "#FindCUDA.cmake generated file.  Do not edit.\n")
+  endif()
+
+endmacro()
+
+###############################################################################
+###############################################################################
+# Setup variables' defaults
+###############################################################################
+###############################################################################
+
+# Allow the user to specify if the device code is supposed to be 32 or 64 bit.
+# The default follows the pointer size of the current build: ON when
+# CMAKE_SIZEOF_VOID_P is 8, OFF otherwise.
+if(CMAKE_SIZEOF_VOID_P EQUAL 8)
+  set(CUDA_64_BIT_DEVICE_CODE_DEFAULT ON)
+else()
+  set(CUDA_64_BIT_DEVICE_CODE_DEFAULT OFF)
+endif()
+option(CUDA_64_BIT_DEVICE_CODE "Compile device code in 64 bit mode" ${CUDA_64_BIT_DEVICE_CODE_DEFAULT})
+
+# Attach the build rule to the source file in VS.  This option should only be
+# enabled when the CUDA source file is added to at most one target.
+option(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE "Attach the build rule to the CUDA source file.  Enable only when the CUDA source file is added to at most one target." ON)
+
+# Prints out extra information about the cuda file during compilation
+option(CUDA_BUILD_CUBIN "Generate and parse .cubin files in Device mode." OFF)
+
+# Set whether we are using emulation or device mode.
+option(CUDA_BUILD_EMULATION "Build in Emulation mode" OFF)
+
+# Where to put the generated output.
+set(CUDA_GENERATED_OUTPUT_DIR "" CACHE PATH "Directory to put all the output files.  If blank it will default to the CMAKE_CURRENT_BINARY_DIR")
+
+# Parse HOST_COMPILATION mode.
+option(CUDA_HOST_COMPILATION_CPP "Generated file extension" ON)
+
+# Extra user settable flags
+cmake_initialize_per_config_variable(CUDA_NVCC_FLAGS "Semi-colon delimit multiple arguments.")
+
+# Choose the default value of the CUDA_HOST_COMPILER cache entry.  Every set()
+# below is a cache set without FORCE, so a value already present in the cache
+# is kept.  The candidates, in order:
+#   1. the CUDAHOSTCXX environment variable, when defined;
+#   2. with a Visual Studio generator, a directory expressed with VS macros
+#      (a different layout is used when MSVC_VERSION is less than 1910);
+#   3. otherwise CMAKE_C_COMPILER, with special cases for Apple clang reached
+#      through a path ending in "/cc" and for MSVC wrapped by clcache/sccache.
+if(DEFINED ENV{CUDAHOSTCXX})
+  set(CUDA_HOST_COMPILER "$ENV{CUDAHOSTCXX}" CACHE FILEPATH "Host side compiler used by NVCC")
+elseif(CMAKE_GENERATOR MATCHES "Visual Studio")
+  set(_CUDA_MSVC_HOST_COMPILER "$(VCInstallDir)Tools/MSVC/$(VCToolsVersion)/bin/Host$(Platform)/$(PlatformTarget)")
+  if(MSVC_VERSION LESS 1910)
+   set(_CUDA_MSVC_HOST_COMPILER "$(VCInstallDir)bin")
+  endif()
+
+  set(CUDA_HOST_COMPILER "${_CUDA_MSVC_HOST_COMPILER}" CACHE FILEPATH "Host side compiler used by NVCC")
+
+else()
+  if(APPLE
+      AND "${CMAKE_C_COMPILER_ID}" MATCHES "Clang"
+      AND "${CMAKE_C_COMPILER}" MATCHES "/cc$")
+    # Using cc which is symlink to clang may let NVCC think it is GCC and issue
+    # unhandled -dumpspecs option to clang. Also in case neither
+    # CMAKE_C_COMPILER is defined (project does not use C language) nor
+    # CUDA_HOST_COMPILER is specified manually we should skip -ccbin and let
+    # nvcc use its own default C compiler.
+    # Only care about this on APPLE with clang to avoid
+    # following symlinks to things like ccache
+    if(DEFINED CMAKE_C_COMPILER AND NOT DEFINED CUDA_HOST_COMPILER)
+      get_filename_component(c_compiler_realpath "${CMAKE_C_COMPILER}" REALPATH)
+      # if the real path does not end up being clang then
+      # go back to using CMAKE_C_COMPILER
+      if(NOT "${c_compiler_realpath}" MATCHES "/clang$")
+        set(c_compiler_realpath "${CMAKE_C_COMPILER}")
+      endif()
+    else()
+      # Leave the default empty in this case.
+      set(c_compiler_realpath "")
+    endif()
+    set(CUDA_HOST_COMPILER "${c_compiler_realpath}" CACHE FILEPATH "Host side compiler used by NVCC")
+  elseif(MSVC AND "${CMAKE_C_COMPILER}" MATCHES "clcache|sccache")
+    # NVCC does not think it will work if it is passed clcache.exe or sccache.exe
+    # as the host compiler, which means that builds with CC=cl.exe won't work.
+    # Best to just feed it whatever the actual cl.exe is as the host compiler.
+    set(CUDA_HOST_COMPILER "cl.exe" CACHE FILEPATH "Host side compiler used by NVCC")
+  else()
+    set(CUDA_HOST_COMPILER "${CMAKE_C_COMPILER}"
+      CACHE FILEPATH "Host side compiler used by NVCC")
+  endif()
+endif()
+
+# Propagate the host flags to the host compiler via -Xcompiler
+option(CUDA_PROPAGATE_HOST_FLAGS "Propagate C/CXX_FLAGS and friends to the host compiler via -Xcompile" ON)
+
+# Blacklisted flags to prevent propagation
+set(CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST  "" CACHE STRING "Blacklisted flags to prevent propagation")
+
+# Enable CUDA_SEPARABLE_COMPILATION
+option(CUDA_SEPARABLE_COMPILATION "Compile CUDA objects with separable compilation enabled.  Requires CUDA 5.0+" OFF)
+
+# Specifies whether the commands used when compiling the .cu file will be printed out.
+option(CUDA_VERBOSE_BUILD "Print out the commands run while compiling the CUDA source file.  With the Makefile generator this defaults to VERBOSE variable specified on the command line, but can be forced on with this option." OFF)
+
+mark_as_advanced(
+  CUDA_64_BIT_DEVICE_CODE
+  CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE
+  CUDA_GENERATED_OUTPUT_DIR
+  CUDA_HOST_COMPILATION_CPP
+  CUDA_NVCC_FLAGS
+  CUDA_PROPAGATE_HOST_FLAGS
+  CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST
+  CUDA_BUILD_CUBIN
+  CUDA_BUILD_EMULATION
+  CUDA_VERBOSE_BUILD
+  CUDA_SEPARABLE_COMPILATION
+  )
+
+# Single config generators like Makefiles or Ninja don't usually have
+# CMAKE_CONFIGURATION_TYPES defined (but note that it can be defined if set by
+# projects or developers). Even CMAKE_BUILD_TYPE might not be defined for
+# single config generators (and should not be defined for multi-config
+# generators). To ensure we get a complete superset of all possible
+# configurations, we combine CMAKE_CONFIGURATION_TYPES, CMAKE_BUILD_TYPE and
+# all of the standard configurations, then weed out duplicates with
+# list(REMOVE_DUPLICATES). Looping over the unique set then ensures we have
+# each configuration-specific set of nvcc flags defined and marked as advanced.
+set(CUDA_configuration_types ${CMAKE_CONFIGURATION_TYPES} ${CMAKE_BUILD_TYPE} Debug MinSizeRel Release RelWithDebInfo)
+list(REMOVE_DUPLICATES CUDA_configuration_types)
+
+###############################################################################
+###############################################################################
+# Locate CUDA, Set Build Type, etc.
+###############################################################################
+###############################################################################
+
+# Remove the cached include directory and library locations so that the
+# find_path()/find_library() calls further down search for them again.  Used
+# when the toolkit root or target directory changes.  Takes no arguments.
+macro(cuda_unset_include_and_libraries)
+  unset(CUDA_TOOLKIT_INCLUDE CACHE)
+  unset(CUDA_CUDART_LIBRARY CACHE)
+  unset(CUDA_CUDA_LIBRARY CACHE)
+  # Make sure you run this before you unset CUDA_VERSION.
+  unset(CUDA_cudart_static_LIBRARY CACHE)
+  unset(CUDA_cudadevrt_LIBRARY CACHE)
+  unset(CUDA_cublas_LIBRARY CACHE)
+  unset(CUDA_cublas_device_LIBRARY CACHE)
+  unset(CUDA_cublasemu_LIBRARY CACHE)
+  unset(CUDA_cublasLt_LIBRARY CACHE)
+  unset(CUDA_cufft_LIBRARY CACHE)
+  unset(CUDA_cufftemu_LIBRARY CACHE)
+  unset(CUDA_cupti_LIBRARY CACHE)
+  unset(CUDA_curand_LIBRARY CACHE)
+  unset(CUDA_cusolver_LIBRARY CACHE)
+  unset(CUDA_cusparse_LIBRARY CACHE)
+  unset(CUDA_npp_LIBRARY CACHE)
+  unset(CUDA_nppc_LIBRARY CACHE)
+  unset(CUDA_nppi_LIBRARY CACHE)
+  # The nppi* libraries below are located with find_cuda_helper_libs() and
+  # therefore cached; clear them as well, otherwise paths into the previous
+  # toolkit would survive a change of toolkit directory.
+  unset(CUDA_nppial_LIBRARY CACHE)
+  unset(CUDA_nppicc_LIBRARY CACHE)
+  unset(CUDA_nppicom_LIBRARY CACHE)
+  unset(CUDA_nppidei_LIBRARY CACHE)
+  unset(CUDA_nppif_LIBRARY CACHE)
+  unset(CUDA_nppig_LIBRARY CACHE)
+  unset(CUDA_nppim_LIBRARY CACHE)
+  unset(CUDA_nppist_LIBRARY CACHE)
+  unset(CUDA_nppisu_LIBRARY CACHE)
+  unset(CUDA_nppitc_LIBRARY CACHE)
+  unset(CUDA_npps_LIBRARY CACHE)
+  unset(CUDA_nvcuvenc_LIBRARY CACHE)
+  unset(CUDA_nvcuvid_LIBRARY CACHE)
+  unset(CUDA_GPU_DETECT_OUTPUT CACHE)
+endmacro()
+
+# Check to see if CUDA_TOOLKIT_ROOT_DIR or CUDA_TOOLKIT_TARGET_DIR differs from
+# the value recorded in the matching *_INTERNAL variable.  If so, clear the
+# dependent cache variables so that they will be detected again.
+if(NOT "${CUDA_TOOLKIT_ROOT_DIR}" STREQUAL "${CUDA_TOOLKIT_ROOT_DIR_INTERNAL}")
+  unset(CUDA_TOOLKIT_TARGET_DIR CACHE)
+  unset(CUDA_NVCC_EXECUTABLE CACHE)
+  cuda_unset_include_and_libraries()
+  # Must come after cuda_unset_include_and_libraries() (see the note inside it).
+  unset(CUDA_VERSION CACHE)
+endif()
+
+if(NOT "${CUDA_TOOLKIT_TARGET_DIR}" STREQUAL "${CUDA_TOOLKIT_TARGET_DIR_INTERNAL}")
+  cuda_unset_include_and_libraries()
+endif()
+
+#
+#  End of unset()
+#
+
+#
+#  Start looking for things
+#
+
+# Search for the cuda distribution.
+# Only done when CUDA_TOOLKIT_ROOT_DIR is not already set and we are not cross
+# compiling.  nvcc is located first and the toolkit root is derived from its
+# location (the parent of the directory that contains nvcc).
+if(NOT CUDA_TOOLKIT_ROOT_DIR AND NOT CMAKE_CROSSCOMPILING)
+  # Search the locations named by the CUDA_TOOLKIT_ROOT, CUDA_PATH and
+  # CUDA_BIN_PATH environment variables first.
+  find_program(CUDA_TOOLKIT_ROOT_DIR_NVCC
+    NAMES nvcc nvcc.exe
+    PATHS
+      ENV CUDA_TOOLKIT_ROOT
+      ENV CUDA_PATH
+      ENV CUDA_BIN_PATH
+    PATH_SUFFIXES bin bin64
+    DOC "Toolkit location."
+    NO_DEFAULT_PATH
+    )
+
+  # Now search default paths
+  find_program(CUDA_TOOLKIT_ROOT_DIR_NVCC
+    NAMES nvcc nvcc.exe
+    PATHS /opt/cuda/bin
+    PATH_SUFFIXES cuda/bin
+    DOC "Toolkit location."
+    )
+
+  if (CUDA_TOOLKIT_ROOT_DIR_NVCC)
+    get_filename_component(CUDA_TOOLKIT_ROOT_DIR_NVCC_PAR "${CUDA_TOOLKIT_ROOT_DIR_NVCC}" DIRECTORY)
+    get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDA_TOOLKIT_ROOT_DIR_NVCC_PAR}" DIRECTORY CACHE)
+    # Strip a trailing bin directory component, if any.  The input is quoted so
+    # that it is always passed as exactly one argument.
+    string(REGEX REPLACE "[/\\\\]?bin[64]*[/\\\\]?$" "" CUDA_TOOLKIT_ROOT_DIR "${CUDA_TOOLKIT_ROOT_DIR}")
+    # We need to force this back into the cache.
+    set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT_DIR} CACHE PATH "Toolkit location." FORCE)
+    set(CUDA_TOOLKIT_TARGET_DIR ${CUDA_TOOLKIT_ROOT_DIR})
+  endif()
+  # The helper cache entry is only needed during the search above.
+  unset(CUDA_TOOLKIT_ROOT_DIR_NVCC CACHE)
+
+  # Quoted so that an empty CUDA_TOOLKIT_ROOT_DIR (nvcc not found) still forms
+  # a well-defined EXISTS test instead of a bare `if(NOT EXISTS)`.
+  if (NOT EXISTS "${CUDA_TOOLKIT_ROOT_DIR}")
+    if(CUDA_FIND_REQUIRED)
+      message(FATAL_ERROR "Specify CUDA_TOOLKIT_ROOT_DIR")
+    elseif(NOT CUDA_FIND_QUIETLY)
+      message("CUDA_TOOLKIT_ROOT_DIR not found or specified")
+    endif()
+  endif ()
+endif ()
+
+# Set up CUDA_TOOLKIT_TARGET_DIR and define the cuda_find_host_program() macro.
+# When cross compiling, the toolkit root is taken from the CUDA_TOOLKIT_ROOT
+# environment variable and the target directory is the first existing
+# <root>/targets/<name> for the candidate names chosen from
+# CMAKE_SYSTEM_PROCESSOR.  Otherwise the target directory is simply
+# CUDA_TOOLKIT_ROOT_DIR.
+if(CMAKE_CROSSCOMPILING)
+  SET (CUDA_TOOLKIT_ROOT $ENV{CUDA_TOOLKIT_ROOT})
+  if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7-a")
+    # Support for NVPACK
+    set (CUDA_TOOLKIT_TARGET_NAMES "armv7-linux-androideabi")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm")
+    # Support for arm cross compilation
+    set(CUDA_TOOLKIT_TARGET_NAMES "armv7-linux-gnueabihf")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+    # Support for aarch64 cross compilation
+    if (ANDROID_ARCH_NAME STREQUAL "arm64")
+      set(CUDA_TOOLKIT_TARGET_NAMES "aarch64-linux-androideabi")
+    else()
+      set(CUDA_TOOLKIT_TARGET_NAMES "aarch64-linux" "sbsa-linux")
+    endif (ANDROID_ARCH_NAME STREQUAL "arm64")
+  endif()
+
+  # Use the first candidate target directory that exists under the toolkit root.
+  foreach(CUDA_TOOLKIT_TARGET_NAME IN LISTS CUDA_TOOLKIT_TARGET_NAMES)
+    if (EXISTS "${CUDA_TOOLKIT_ROOT}/targets/${CUDA_TOOLKIT_TARGET_NAME}")
+      set(CUDA_TOOLKIT_TARGET_DIR "${CUDA_TOOLKIT_ROOT}/targets/${CUDA_TOOLKIT_TARGET_NAME}" CACHE PATH "CUDA Toolkit target location.")
+      SET (CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT} CACHE PATH "Toolkit location." FORCE)
+      mark_as_advanced(CUDA_TOOLKIT_TARGET_DIR)
+      break()
+    endif()
+  endforeach()
+
+  # add known CUDA target root path to the set of directories we search for programs, libraries and headers
+  set( CMAKE_FIND_ROOT_PATH "${CUDA_TOOLKIT_TARGET_DIR};${CMAKE_FIND_ROOT_PATH}")
+  # Forward to find_host_program() when that command is available, otherwise
+  # fall back to find_program().
+  macro( cuda_find_host_program )
+    if (COMMAND find_host_program)
+      find_host_program( ${ARGN} )
+    else()
+      find_program( ${ARGN} )
+    endif()
+  endmacro()
+else()
+  # for non-cross-compile, find_host_program == find_program and CUDA_TOOLKIT_TARGET_DIR == CUDA_TOOLKIT_ROOT_DIR
+  macro( cuda_find_host_program )
+    find_program( ${ARGN} )
+  endmacro()
+  SET (CUDA_TOOLKIT_TARGET_DIR ${CUDA_TOOLKIT_ROOT_DIR})
+endif()
+
+
+# CUDA_NVCC_EXECUTABLE
+# Taken from the CUDA_NVCC_EXECUTABLE environment variable when that is
+# defined.  Otherwise nvcc is searched for under CUDA_TOOLKIT_ROOT_DIR and the
+# CUDA_PATH / CUDA_BIN_PATH environment variables (bin and bin64
+# subdirectories), and finally in the default search paths.
+if(DEFINED ENV{CUDA_NVCC_EXECUTABLE})
+  set(CUDA_NVCC_EXECUTABLE "$ENV{CUDA_NVCC_EXECUTABLE}" CACHE FILEPATH "The CUDA compiler")
+else()
+  cuda_find_host_program(CUDA_NVCC_EXECUTABLE
+    NAMES nvcc
+    PATHS "${CUDA_TOOLKIT_ROOT_DIR}"
+    ENV CUDA_PATH
+    ENV CUDA_BIN_PATH
+    PATH_SUFFIXES bin bin64
+    NO_DEFAULT_PATH
+    )
+  # Search default search paths, after we search our own set of paths.
+  cuda_find_host_program(CUDA_NVCC_EXECUTABLE nvcc)
+endif()
+
+# Determine CUDA_VERSION (cached), CUDA_VERSION_MAJOR and CUDA_VERSION_MINOR.
+# When nvcc is known and no version is cached yet, the version is parsed from
+# the "release <major>.<minor>" text printed by `nvcc --version`; otherwise the
+# major/minor parts are re-derived from the cached CUDA_VERSION.
+if(CUDA_NVCC_EXECUTABLE AND NOT CUDA_VERSION)
+  # Compute the version.
+  execute_process(COMMAND ${CUDA_NVCC_EXECUTABLE} "--version"
+    OUTPUT_VARIABLE NVCC_OUT
+    RESULT_VARIABLE NVCC_RC)
+  # RESULT_VARIABLE holds either an integer exit code or a string describing an
+  # error (e.g. when the program cannot be started).  Name the variable instead
+  # of expanding it, so that a multi-word error string cannot break the if()
+  # syntax; a non-numeric value compares unequal to 0 and is treated as failure.
+  if(NOT NVCC_RC EQUAL 0)
+    message(WARNING "Failed to execute '${CUDA_NVCC_EXECUTABLE} --version'")
+    set(CUDA_FOUND FALSE)
+    return()
+  endif()
+  # The output is quoted so that it is always passed as a single argument, even
+  # if it is empty or contains semicolons.
+  string(REGEX REPLACE ".*release ([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR "${NVCC_OUT}")
+  string(REGEX REPLACE ".*release ([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR "${NVCC_OUT}")
+  set(CUDA_VERSION "${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}" CACHE STRING "Version of CUDA as computed from nvcc.")
+  mark_as_advanced(CUDA_VERSION)
+else()
+  # Need to set these based off of the cached value
+  string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR "${CUDA_VERSION}")
+  string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR "${CUDA_VERSION}")
+endif()
+
+# Always set this convenience variable
+set(CUDA_VERSION_STRING "${CUDA_VERSION}")
+
+# CUDA_TOOLKIT_INCLUDE
+find_path(CUDA_TOOLKIT_INCLUDE
+  device_functions.h # Header included in toolkit
+  PATHS ${CUDA_TOOLKIT_TARGET_DIR}
+  ENV CUDA_PATH
+  ENV CUDA_INC_PATH
+  PATH_SUFFIXES include
+  NO_DEFAULT_PATH
+  )
+# Search default search paths, after we search our own set of paths.
+find_path(CUDA_TOOLKIT_INCLUDE device_functions.h)
+mark_as_advanced(CUDA_TOOLKIT_INCLUDE)
+
+set(CUDA_HAS_FP16 TRUE)
+
+# Set the user list of include dir to nothing to initialize it.
+set (CUDA_NVCC_INCLUDE_DIRS_USER "")
+set (CUDA_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
+
+# Find a library, looking inside the CUDA toolkit before anywhere else.
+#   _var      -- name of the cache variable that receives the result.
+#   _names    -- library name(s) to look for.
+#   _doc      -- documentation string for the cache entry.
+#   _path_ext -- prefix prepended to each library sub-directory suffix
+#                (e.g. "extras/CUPTI/"); may be empty.
+# The first search is restricted to CUDA_TOOLKIT_TARGET_DIR and the CUDA_PATH /
+# CUDA_LIB_PATH environment variables.  When not cross compiling, a second
+# search over the default paths (plus /usr/lib/nvidia-current) follows.
+macro(cuda_find_library_local_first_with_path_ext _var _names _doc _path_ext )
+  if(CMAKE_SIZEOF_VOID_P EQUAL 8)
+    # CUDA 3.2+ on Windows moved the library directories, so we need the new
+    # and old paths.
+    set(_cuda_64bit_lib_dir "${_path_ext}lib/x64" "${_path_ext}lib64" "${_path_ext}libx64" )
+  endif()
+  # CUDA 3.2+ on Windows moved the library directories, so we need the new
+  # (lib/Win32) and the old path (lib).
+  find_library(${_var}
+    NAMES ${_names}
+    PATHS "${CUDA_TOOLKIT_TARGET_DIR}"
+    ENV CUDA_PATH
+    ENV CUDA_LIB_PATH
+    PATH_SUFFIXES ${_cuda_64bit_lib_dir} "${_path_ext}lib/Win32" "${_path_ext}lib" "${_path_ext}libWin32"
+    DOC ${_doc}
+    NO_DEFAULT_PATH
+    )
+  if (NOT CMAKE_CROSSCOMPILING)
+    # Search default search paths, after we search our own set of paths.
+    find_library(${_var}
+      NAMES ${_names}
+      PATHS "/usr/lib/nvidia-current"
+      DOC ${_doc}
+      )
+  endif()
+endmacro()
+
+# Same as cuda_find_library_local_first_with_path_ext() with an empty path
+# prefix, i.e. the library is looked up directly in the toolkit's library
+# sub-directories.
+macro(cuda_find_library_local_first _var _names _doc)
+  cuda_find_library_local_first_with_path_ext( "${_var}" "${_names}" "${_doc}" "" )
+endmacro()
+
+# Un-prefixed alias that forwards to cuda_find_library_local_first().
+#   _var   -- name of the cache variable that receives the result.
+#   _names -- library name(s) to look for.
+#   _doc   -- documentation string for the cache entry.
+macro(find_library_local_first _var _names _doc )
+  # cuda_find_library_local_first() takes exactly three arguments.
+  cuda_find_library_local_first( "${_var}" "${_names}" "${_doc}" )
+endmacro()
+
+
+# CUDA_LIBRARIES
+cuda_find_library_local_first(CUDA_CUDART_LIBRARY cudart "\"cudart\" library")
+
+cuda_find_library_local_first(CUDA_cudart_static_LIBRARY cudart_static "static CUDA runtime library")
+mark_as_advanced(CUDA_cudart_static_LIBRARY)
+
+
+# Decide whether the static CUDA runtime is used, and record in
+# CUDA_CUDART_LIBRARY_VAR the name of the variable that holds the selected
+# runtime library.
+if(CUDA_cudart_static_LIBRARY)
+  # If static cudart available, use it by default, but provide a user-visible option to disable it.
+  option(CUDA_USE_STATIC_CUDA_RUNTIME "Use the static version of the CUDA runtime library if available" ON)
+else()
+  # If not available, silently disable the option.
+  set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "")
+endif()
+
+if(CUDA_USE_STATIC_CUDA_RUNTIME)
+  set(CUDA_CUDART_LIBRARY_VAR CUDA_cudart_static_LIBRARY)
+else()
+  set(CUDA_CUDART_LIBRARY_VAR CUDA_CUDART_LIBRARY)
+endif()
+
+cuda_find_library_local_first(CUDA_cudadevrt_LIBRARY cudadevrt "\"cudadevrt\" library")
+mark_as_advanced(CUDA_cudadevrt_LIBRARY)
+
+# Locate the extra libraries needed with the static CUDA runtime on UNIX:
+# the Threads package (required) and, except on APPLE, librt (a warning is
+# issued when it is missing).  CMAKE_THREAD_PREFER_PTHREAD and CMAKE_C_FLAGS
+# are overridden only for the duration of find_package(Threads) and restored
+# afterwards.
+if(CUDA_USE_STATIC_CUDA_RUNTIME)
+  if(UNIX)
+    # Check for the dependent libraries.  Here we look for pthreads.
+    # Remember a caller-provided preference so it can be restored below.
+    if (DEFINED CMAKE_THREAD_PREFER_PTHREAD)
+      set(_cuda_cmake_thread_prefer_pthread ${CMAKE_THREAD_PREFER_PTHREAD})
+    endif()
+    set(CMAKE_THREAD_PREFER_PTHREAD 1)
+
+    # Many of the FindXYZ CMake comes with makes use of try_compile with int main(){return 0;}
+    # as the source file.  Unfortunately this causes a warning with -Wstrict-prototypes and
+    # -Werror causes the try_compile to fail.  We will just temporarily disable other flags
+    # when doing the find_package command here.
+    set(_cuda_cmake_c_flags ${CMAKE_C_FLAGS})
+    set(CMAKE_C_FLAGS "-fPIC")
+    find_package(Threads REQUIRED)
+    set(CMAKE_C_FLAGS ${_cuda_cmake_c_flags})
+
+    # Restore the previous preference, or remove ours if there was none.
+    if (DEFINED _cuda_cmake_thread_prefer_pthread)
+      set(CMAKE_THREAD_PREFER_PTHREAD ${_cuda_cmake_thread_prefer_pthread})
+      unset(_cuda_cmake_thread_prefer_pthread)
+    else()
+      unset(CMAKE_THREAD_PREFER_PTHREAD)
+    endif()
+
+    if(NOT APPLE)
+      #On Linux, you must link against librt when using the static cuda runtime.
+      find_library(CUDA_rt_LIBRARY rt)
+      if (NOT CUDA_rt_LIBRARY)
+        message(WARNING "Expecting to find librt for libcudart_static, but didn't find it.")
+      endif()
+    endif()
+  endif()
+endif()
+
+cuda_find_library_local_first_with_path_ext(CUDA_cupti_LIBRARY cupti "\"cupti\" library" "extras/CUPTI/")
+mark_as_advanced(CUDA_cupti_LIBRARY)
+
+# Set the CUDA_LIBRARIES variable.  This is the set of stuff to link against if you are
+# using the CUDA runtime.  For the dynamic version of the runtime, most of the
+# dependencies are brought in, but for the static version there are additional libraries
+# and linker commands needed.
+# Initialize to empty
+set(CUDA_LIBRARIES)
+
+# If we are using emulation mode and we found the cudartemu library then use
+# that one instead of cudart.
+# NOTE(review): CUDA_BUILD_EMULATION is rejected with a FATAL_ERROR further
+# down in this file, so the emulation branch looks unreachable in a build that
+# configures successfully -- confirm before relying on it.
+if(CUDA_BUILD_EMULATION AND CUDA_CUDARTEMU_LIBRARY)
+  list(APPEND CUDA_LIBRARIES ${CUDA_CUDARTEMU_LIBRARY})
+elseif(CUDA_USE_STATIC_CUDA_RUNTIME AND CUDA_cudart_static_LIBRARY)
+  # Static runtime: also link the thread library, the dl libraries and, when
+  # it was found, librt.
+  list(APPEND CUDA_LIBRARIES ${CUDA_cudart_static_LIBRARY} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS})
+  if (CUDA_rt_LIBRARY)
+    list(APPEND CUDA_LIBRARIES ${CUDA_rt_LIBRARY})
+  endif()
+  if(APPLE)
+    # We need to add the default path to the driver (libcuda.dylib) as an rpath, so that
+    # the static cuda runtime can find it at runtime.
+    list(APPEND CUDA_LIBRARIES -Wl,-rpath,/usr/local/cuda/lib)
+  endif()
+else()
+  list(APPEND CUDA_LIBRARIES ${CUDA_CUDART_LIBRARY})
+endif()
+
+# 1.1 toolkit on linux doesn't appear to have a separate library on
+# some platforms.
+cuda_find_library_local_first(CUDA_CUDA_LIBRARY cuda "\"cuda\" library (older versions only).")
+
+mark_as_advanced(
+  CUDA_CUDA_LIBRARY
+  CUDA_CUDART_LIBRARY
+  )
+
+#######################
+# Look for some of the toolkit helper libraries
+# Locate one toolkit helper library by its base name.  The search result is stored
+# in CUDA_<name>_LIBRARY, which is then marked as an advanced cache entry.
+macro(FIND_CUDA_HELPER_LIBS _lib)
+  cuda_find_library_local_first(
+    CUDA_${_lib}_LIBRARY
+    ${_lib}
+    "\"${_lib}\" library"
+    )
+  mark_as_advanced(CUDA_${_lib}_LIBRARY)
+endmacro()
+
+if(CUDA_BUILD_EMULATION) # emulation builds are rejected outright; configuration stops here
+  message(FATAL_ERROR "CUDA_BUILD_EMULATION is not supported in version 3.1 and onwards.  You must disable it to proceed.  You have version ${CUDA_VERSION}.")
+endif()
+
+find_cuda_helper_libs(cufft)
+find_cuda_helper_libs(cublas)
+find_cuda_helper_libs(cublasLt)
+# cusparse showed up in version 3.2
+find_cuda_helper_libs(cusparse)
+find_cuda_helper_libs(curand)
+if (WIN32)
+  find_cuda_helper_libs(nvcuvenc)
+  find_cuda_helper_libs(nvcuvid)
+endif()
+
+# In CUDA 9.0 the monolithic nppi library was removed, so look for each NPP component library individually.
+find_cuda_helper_libs(nppc)
+find_cuda_helper_libs(nppial)
+find_cuda_helper_libs(nppicc)
+find_cuda_helper_libs(nppicom)
+find_cuda_helper_libs(nppidei)
+find_cuda_helper_libs(nppif)
+find_cuda_helper_libs(nppig)
+find_cuda_helper_libs(nppim)
+find_cuda_helper_libs(nppist)
+find_cuda_helper_libs(nppisu)
+find_cuda_helper_libs(nppitc)
+find_cuda_helper_libs(npps)
+set(CUDA_npp_LIBRARY "${CUDA_nppc_LIBRARY};${CUDA_nppial_LIBRARY};${CUDA_nppicc_LIBRARY};${CUDA_nppicom_LIBRARY};${CUDA_nppidei_LIBRARY};${CUDA_nppif_LIBRARY};${CUDA_nppig_LIBRARY};${CUDA_nppim_LIBRARY};${CUDA_nppist_LIBRARY};${CUDA_nppisu_LIBRARY};${CUDA_nppitc_LIBRARY};${CUDA_npps_LIBRARY}") # aggregate of every NPP component looked up above, whether or not each one was found
+# cusolver showed up in version 7.0
+find_cuda_helper_libs(cusolver)
+
+if (CUDA_BUILD_EMULATION) # NOTE(review): emulation already hit FATAL_ERROR above, so this branch looks unreachable
+  set(CUDA_CUFFT_LIBRARIES ${CUDA_cufftemu_LIBRARY})
+  set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublasemu_LIBRARY})
+else()
+  set(CUDA_CUFFT_LIBRARIES ${CUDA_cufft_LIBRARY})
+  set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY}) # NOTE(review): cublas_device is not searched for in this section; it contributes nothing unless set elsewhere
+endif()
+
+########################
+# Look for the SDK.  As of CUDA 3.0 NVSDKCUDA_ROOT has been replaced with
+# NVSDKCOMPUTE_ROOT, with the old CUDA C contents moved into the C subdirectory.
+find_path(CUDA_SDK_ROOT_DIR common/inc/cutil.h # the SDK root is the directory that contains common/inc/cutil.h
+ HINTS
+  "$ENV{NVSDKCOMPUTE_ROOT}/C"
+  ENV NVSDKCUDA_ROOT
+  "[HKEY_LOCAL_MACHINE\\SOFTWARE\\NVIDIA Corporation\\Installed Products\\NVIDIA SDK 10\\Compute;InstallDir]"
+ PATHS
+  "/Developer/GPU\ Computing/C"
+  )
+
+# Keep the CUDA_SDK_ROOT_DIR first in order to be able to override the
+# environment variables.  The remaining entries are fixed fallback locations.
+set(CUDA_SDK_SEARCH_PATH
+  "${CUDA_SDK_ROOT_DIR}"
+  "${CUDA_TOOLKIT_ROOT_DIR}/local/NVSDK0.2"
+  "${CUDA_TOOLKIT_ROOT_DIR}/NVSDK0.2"
+  "${CUDA_TOOLKIT_ROOT_DIR}/NV_CUDA_SDK"
+  "$ENV{HOME}/NVIDIA_CUDA_SDK"
+  "$ENV{HOME}/NVIDIA_CUDA_SDK_MACOSX"
+  "/Developer/CUDA"
+  )
+
+# Example of how to find an include file from the CUDA_SDK_ROOT_DIR
+
+# find_path(CUDA_CUT_INCLUDE_DIR
+#   cutil.h
+#   PATHS ${CUDA_SDK_SEARCH_PATH}
+#   PATH_SUFFIXES "common/inc"
+#   DOC "Location of cutil.h"
+#   NO_DEFAULT_PATH
+#   )
+# # Now search system paths
+# find_path(CUDA_CUT_INCLUDE_DIR cutil.h DOC "Location of cutil.h")
+
+# mark_as_advanced(CUDA_CUT_INCLUDE_DIR)
+
+
+# Example of how to find a library in the CUDA_SDK_ROOT_DIR
+
+# # cutil library is called cutil64 for 64 bit builds on windows.  We don't want
+# # to get these confused, so we are setting the name based on the word size of
+# # the build.
+
+# if(CMAKE_SIZEOF_VOID_P EQUAL 8)
+#   set(cuda_cutil_name cutil64)
+# else()
+#   set(cuda_cutil_name cutil32)
+# endif()
+
+# find_library(CUDA_CUT_LIBRARY
+#   NAMES cutil ${cuda_cutil_name}
+#   PATHS ${CUDA_SDK_SEARCH_PATH}
+#   # The new version of the sdk shows up in common/lib, but the old one is in lib
+#   PATH_SUFFIXES "common/lib" "lib"
+#   DOC "Location of cutil library"
+#   NO_DEFAULT_PATH
+#   )
+# # Now search system paths
+# find_library(CUDA_CUT_LIBRARY NAMES cutil ${cuda_cutil_name} DOC "Location of cutil library")
+# mark_as_advanced(CUDA_CUT_LIBRARY)
+# set(CUDA_CUT_LIBRARIES ${CUDA_CUT_LIBRARY})
+
+
+
+#############################
+# Check for required components and remember the directories used by this run.
+set(CUDA_FOUND TRUE) # provisional value; find_package_handle_standard_args below makes the final decision
+
+set(CUDA_TOOLKIT_ROOT_DIR_INTERNAL "${CUDA_TOOLKIT_ROOT_DIR}" CACHE INTERNAL
+  "This is the value of the last time CUDA_TOOLKIT_ROOT_DIR was set successfully." FORCE) # FORCE overwrites any value cached by an earlier run
+set(CUDA_TOOLKIT_TARGET_DIR_INTERNAL "${CUDA_TOOLKIT_TARGET_DIR}" CACHE INTERNAL
+  "This is the value of the last time CUDA_TOOLKIT_TARGET_DIR was set successfully." FORCE)
+set(CUDA_SDK_ROOT_DIR_INTERNAL "${CUDA_SDK_ROOT_DIR}" CACHE INTERNAL
+  "This is the value of the last time CUDA_SDK_ROOT_DIR was set successfully." FORCE)
+
+find_package_handle_standard_args(CUDA
+  REQUIRED_VARS
+    CUDA_TOOLKIT_ROOT_DIR
+    CUDA_NVCC_EXECUTABLE
+    CUDA_INCLUDE_DIRS
+    ${CUDA_CUDART_LIBRARY_VAR} # expands to a variable *name* set outside this section; presumably the selected runtime library variable - verify
+  VERSION_VAR
+    CUDA_VERSION
+  )
+
+
+
+###############################################################################
+###############################################################################
+# Macros
+###############################################################################
+###############################################################################
+
+###############################################################################
+# Add include directories to pass to the nvcc command.
+macro(CUDA_INCLUDE_DIRECTORIES)
+  # Accumulate every directory passed by the caller onto the user include list
+  # (CUDA_NVCC_INCLUDE_DIRS_USER); calling with no arguments leaves it unchanged.
+  list(APPEND CUDA_NVCC_INCLUDE_DIRS_USER ${ARGN})
+endmacro()
+
+
+##############################################################################
+cuda_find_helper_file(parse_cubin cmake) # locate the FindCUDA helper scripts (helper defined outside this section)
+cuda_find_helper_file(make2cmake cmake)
+cuda_find_helper_file(run_nvcc cmake) # presumably the source of CUDA_run_nvcc used by CUDA_WRAP_SRCS - verify against cuda_find_helper_file
+include("${CMAKE_CURRENT_LIST_DIR}/FindCUDA/select_compute_arch.cmake") # loaded relative to this file's own directory
+
+##############################################################################
+# Separate the OPTIONS out from the sources
+#
+macro(CUDA_GET_SOURCES_AND_OPTIONS _sources _cmake_options _options)
+  # Split ARGN into three output lists:
+  #   ${_sources}       - everything before the OPTIONS keyword that is not a keyword
+  #   ${_cmake_options} - add_library/add_executable keywords, wherever they appear
+  #   ${_options}       - everything after the OPTIONS keyword that is not a keyword
+  set(${_sources})
+  set(${_cmake_options})
+  set(${_options})
+  set(_found_options FALSE)
+  foreach(arg ${ARGN})
+    if("x${arg}" STREQUAL "xOPTIONS")
+      # From here on, non-keyword arguments are nvcc options rather than sources.
+      set(_found_options TRUE)
+    elseif("x${arg}" STREQUAL "xWIN32"
+        OR "x${arg}" STREQUAL "xMACOSX_BUNDLE"
+        OR "x${arg}" STREQUAL "xEXCLUDE_FROM_ALL"
+        OR "x${arg}" STREQUAL "xSTATIC"
+        OR "x${arg}" STREQUAL "xSHARED"
+        OR "x${arg}" STREQUAL "xMODULE")
+      list(APPEND ${_cmake_options} ${arg})
+    elseif(_found_options)
+      list(APPEND ${_options} ${arg})
+    else()
+      # No OPTIONS keyword seen yet: treat the argument as a source file.
+      list(APPEND ${_sources} ${arg})
+    endif()
+  endforeach()
+endmacro()
+
+##############################################################################
+# Parse the OPTIONS from ARGN and set the variables prefixed by _option_prefix
+#
+macro(CUDA_PARSE_NVCC_OPTIONS _option_prefix)
+  # Distribute the nvcc options in ARGN over per-configuration lists.
+  #   _option_prefix - name of the list receiving options that precede any
+  #                    configuration keyword; options following a keyword such as
+  #                    DEBUG are appended to ${_option_prefix}_DEBUG instead.
+  # Configuration keywords are the upper-cased entries of CUDA_configuration_types.
+  set( _found_config )
+  foreach(arg ${ARGN})
+    # Determine if we are dealing with a per-configuration keyword
+    foreach(config ${CUDA_configuration_types})
+      string(TOUPPER "${config}" config_upper)
+      if ("x${arg}" STREQUAL "x${config_upper}")
+        set( _found_config _${arg})
+        # Set arg to nothing to keep it from being processed further
+        set( arg )
+        break()
+      endif()
+    endforeach()
+
+    # Compare against the empty string instead of using if(arg): a bare truth test
+    # would silently drop option values that CMake treats as false constants
+    # (0, OFF, NO, N, FALSE, IGNORE, *-NOTFOUND), e.g. the 0 in "-maxrregcount 0".
+    if (NOT "x${arg}" STREQUAL "x")
+      list(APPEND ${_option_prefix}${_found_config} "${arg}")
+    endif()
+  endforeach()
+endmacro()
+
+##############################################################################
+# Helper to add the include directory for CUDA only once
+function(CUDA_ADD_CUDA_INCLUDE_ONCE)
+  # Add CUDA_INCLUDE_DIRS to the directory's include path unless one of the entries
+  # already present compares equal to it.
+  get_directory_property(_current_dirs INCLUDE_DIRECTORIES)
+  set(_already_listed FALSE)
+  if(_current_dirs)
+    foreach(_entry ${_current_dirs})
+      if("${_entry}" STREQUAL "${CUDA_INCLUDE_DIRS}")
+        set(_already_listed TRUE)
+        break()
+      endif()
+    endforeach()
+  endif()
+  if(NOT _already_listed)
+    include_directories(${CUDA_INCLUDE_DIRS})
+  endif()
+endfunction()
+
+function(CUDA_BUILD_SHARED_LIBRARY shared_flag)
+  # Decide which library-type keyword, if any, must be added to the arguments.
+  #   shared_flag - name of the caller's variable that receives the result.
+  # When ARGN already names SHARED, MODULE or STATIC the result is empty; otherwise
+  # it is SHARED or STATIC according to BUILD_SHARED_LIBS.
+  set(_requested ${ARGN})
+  set(_explicit_type FALSE)
+  foreach(_keyword SHARED MODULE STATIC)
+    list(FIND _requested ${_keyword} _keyword_pos)
+    if(_keyword_pos GREATER -1)
+      set(_explicit_type TRUE)
+    endif()
+  endforeach()
+
+  set(_library_type)
+  if(NOT _explicit_type)
+    if(BUILD_SHARED_LIBS)
+      set(_library_type SHARED)
+    else()
+      set(_library_type STATIC)
+    endif()
+  endif()
+  set(${shared_flag} ${_library_type} PARENT_SCOPE)
+endfunction()
+
+##############################################################################
+# Helper to avoid clashes of files with the same basename but different paths.
+# This doesn't attempt to do exactly what CMake internals do, which is to only
+# add this path when there is a conflict, since by the time a second collision
+# in names is detected it's already too late to fix the first one.  For
+# consistency sake the relative path will be added to all files.
+function(CUDA_COMPUTE_BUILD_PATH path build_path)
+  # Turn the location of a source file into a relative directory name that is safe
+  # to use inside the build tree.
+  #   path       - source file path (absolute or relative)
+  #   build_path - name of the caller's variable that receives the directory
+
+  # Work on CMake style (forward slash) paths only.
+  file(TO_CMAKE_PATH "${path}" _rel)
+  if(IS_ABSOLUTE "${_rel}")
+    # Absolute paths are made relative: to the current binary dir when the path
+    # starts with it, otherwise to the current source dir.
+    set(_base_dir "${CMAKE_CURRENT_SOURCE_DIR}")
+    string(FIND "${_rel}" "${CMAKE_CURRENT_BINARY_DIR}" _prefix_pos)
+    if(_prefix_pos EQUAL 0)
+      set(_base_dir "${CMAKE_CURRENT_BINARY_DIR}")
+    endif()
+    file(RELATIVE_PATH _rel "${_base_dir}" "${_rel}")
+  endif()
+
+  # Sanitising steps, following cmLocalGenerator::CreateSafeUniqueObjectFileName
+  # in the CMake source:
+  string(REGEX REPLACE "^[/]+" "" _rel "${_rel}")   # drop leading slashes
+  string(REPLACE ":" "_" _rel "${_rel}")            # no drive-letter colons
+  string(REPLACE "../" "__/" _rel "${_rel}")        # no climbing up the tree
+  string(REPLACE " " "_" _rel "${_rel}")            # no spaces
+
+  # The filename is stripped last: removing it earlier could turn path/../basename
+  # into path/.. before the "../" replacement above had a chance to see it.
+  get_filename_component(_rel "${_rel}" PATH)
+
+  set(${build_path} "${_rel}" PARENT_SCOPE)
+endfunction()
+
+##############################################################################
+# This helper macro populates the following variables and setups up custom
+# commands and targets to invoke the nvcc compiler to generate C or PTX source
+# dependent upon the format parameter.  The compiler is invoked once with -M
+# to generate a dependency file and a second time with -cuda or -ptx to generate
+# a .cpp or .ptx file.
+# INPUT:
+#   cuda_target         - Target name
+#   format              - PTX, CUBIN, FATBIN or OBJ
+#   FILE1 .. FILEN      - The remaining arguments are the sources to be wrapped.
+#   OPTIONS             - Extra options to NVCC
+# OUTPUT:
+#   generated_files     - List of generated files
+##############################################################################
+##############################################################################
+
+macro(CUDA_WRAP_SRCS cuda_target format generated_files)
+
+  # Sets up one nvcc custom command per CUDA source found in ARGN and returns the
+  # list of files those commands generate in ${generated_files}.  Besides sources
+  # and OPTIONS, ARGN may contain the keyword PHONY, which states that cuda_target
+  # is not a real target; include directories and compile definitions are then
+  # read from directory properties instead of target properties.
+
+  # Put optional arguments in list.
+  set(_argn_list "${ARGN}")
+  # If one of the given optional arguments is "PHONY", make a note of it, then
+  # remove it from the list.
+  list(FIND _argn_list "PHONY" _phony_idx)
+  if("${_phony_idx}" GREATER "-1")
+    set(_target_is_phony true)
+    list(REMOVE_AT _argn_list ${_phony_idx})
+  else()
+    set(_target_is_phony false)
+  endif()
+
+  # If CMake doesn't support separable compilation, complain
+  if(CUDA_SEPARABLE_COMPILATION AND CMAKE_VERSION VERSION_LESS "2.8.10.1")
+    message(SEND_ERROR "CUDA_SEPARABLE_COMPILATION isn't supported for CMake versions less than 2.8.10.1")
+  endif()
+
+  # Set up all the command line flags here, so that they can be overridden on a per target basis.
+
+  set(nvcc_flags "")
+
+  # Emulation if the card isn't present.
+  if (CUDA_BUILD_EMULATION)
+    # Emulation.
+    set(nvcc_flags ${nvcc_flags} --device-emulation -D_DEVICEEMU -g)
+  else()
+    # Device mode.  No flags necessary.
+  endif()
+
+  if(CUDA_HOST_COMPILATION_CPP)
+    set(CUDA_C_OR_CXX CXX)
+  else()
+    message(WARNING "--host-compilation flag is deprecated in CUDA version >= 3.0.  Removing --host-compilation C flag" )
+    set(CUDA_C_OR_CXX C)
+  endif()
+
+  set(generated_extension ${CMAKE_${CUDA_C_OR_CXX}_OUTPUT_EXTENSION})
+
+  if(CUDA_64_BIT_DEVICE_CODE)
+    set(nvcc_flags ${nvcc_flags} -m64)
+  else()
+    set(nvcc_flags ${nvcc_flags} -m32)
+  endif()
+
+  if(CUDA_TARGET_CPU_ARCH)
+    set(nvcc_flags ${nvcc_flags} "--target-cpu-architecture=${CUDA_TARGET_CPU_ARCH}")
+  endif()
+
+  # This needs to be passed in at this stage, because VS needs to fill out the
+  # various macros from within VS.  Note that CCBIN is only used if
+  # -ccbin or --compiler-bindir isn't used and CUDA_HOST_COMPILER matches
+  # _CUDA_MSVC_HOST_COMPILER
+  if(CMAKE_GENERATOR MATCHES "Visual Studio")
+    set(ccbin_flags -D "\"CCBIN:PATH=${_CUDA_MSVC_HOST_COMPILER}\"" )
+  else()
+    set(ccbin_flags)
+  endif()
+
+  # Figure out which configure we will use and pass that in as an argument to
+  # the script.  We need to defer the decision until compilation time, because
+  # for VS projects we won't know if we are making a debug or release build
+  # until build time.
+  if(CMAKE_GENERATOR MATCHES "Visual Studio")
+    set( CUDA_build_configuration "$(ConfigurationName)" )
+  else()
+    set( CUDA_build_configuration "${CMAKE_BUILD_TYPE}")
+  endif()
+
+  # Initialize our list of includes with the user ones followed by the CUDA system ones.
+  set(CUDA_NVCC_INCLUDE_DIRS ${CUDA_NVCC_INCLUDE_DIRS_USER} "${CUDA_INCLUDE_DIRS}")
+  if(_target_is_phony)
+    # If the passed in target name isn't a real target (i.e., this is from a call to one of the
+    # cuda_compile_* functions), need to query directory properties to get include directories
+    # and compile definitions.
+    get_directory_property(_dir_include_dirs INCLUDE_DIRECTORIES)
+    get_directory_property(_dir_compile_defs COMPILE_DEFINITIONS)
+
+    list(APPEND CUDA_NVCC_INCLUDE_DIRS "${_dir_include_dirs}")
+    set(CUDA_NVCC_COMPILE_DEFINITIONS "${_dir_compile_defs}")
+  else()
+    # Append the include directories for this target via generator expression, which is
+    # expanded by the FILE(GENERATE) call below.  This generator expression captures all
+    # include dirs set by the user, whether via directory properties or target properties
+    list(APPEND CUDA_NVCC_INCLUDE_DIRS "$<TARGET_PROPERTY:${cuda_target},INCLUDE_DIRECTORIES>")
+
+    # Do the same thing with compile definitions
+    set(CUDA_NVCC_COMPILE_DEFINITIONS "$<TARGET_PROPERTY:${cuda_target},COMPILE_DEFINITIONS>")
+  endif()
+
+
+  # Reset these variables
+  set(CUDA_WRAP_OPTION_NVCC_FLAGS)
+  foreach(config ${CUDA_configuration_types})
+    string(TOUPPER ${config} config_upper)
+    set(CUDA_WRAP_OPTION_NVCC_FLAGS_${config_upper})
+  endforeach()
+
+  CUDA_GET_SOURCES_AND_OPTIONS(_cuda_wrap_sources _cuda_wrap_cmake_options _cuda_wrap_options ${_argn_list})
+  CUDA_PARSE_NVCC_OPTIONS(CUDA_WRAP_OPTION_NVCC_FLAGS ${_cuda_wrap_options})
+
+  # Figure out if we are building a shared library.  BUILD_SHARED_LIBS is
+  # respected in CUDA_ADD_LIBRARY.
+  set(_cuda_build_shared_libs FALSE)
+  # SHARED, MODULE
+  list(FIND _cuda_wrap_cmake_options SHARED _cuda_found_SHARED)
+  list(FIND _cuda_wrap_cmake_options MODULE _cuda_found_MODULE)
+  if(_cuda_found_SHARED GREATER -1 OR _cuda_found_MODULE GREATER -1)
+    set(_cuda_build_shared_libs TRUE)
+  endif()
+  # STATIC
+  list(FIND _cuda_wrap_cmake_options STATIC _cuda_found_STATIC)
+  if(_cuda_found_STATIC GREATER -1)
+    set(_cuda_build_shared_libs FALSE)
+  endif()
+
+  # CUDA_HOST_FLAGS
+  if(_cuda_build_shared_libs)
+    # If we are setting up code for a shared library, then we need to add extra flags for
+    # compiling objects for shared libraries.
+    set(CUDA_HOST_SHARED_FLAGS ${CMAKE_SHARED_LIBRARY_${CUDA_C_OR_CXX}_FLAGS})
+  else()
+    set(CUDA_HOST_SHARED_FLAGS)
+  endif()
+
+  macro(_filter_blocklisted_host_flags CUDA_FLAGS)
+    string(REGEX REPLACE "[ \t]+" ";" ${CUDA_FLAGS} "${${CUDA_FLAGS}}")
+    foreach(_blacklisted ${CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST})
+      list(REMOVE_ITEM ${CUDA_FLAGS} "${_blacklisted}")
+    endforeach()
+    string(REPLACE ";" " " ${CUDA_FLAGS} "${${CUDA_FLAGS}}")
+  endmacro()
+
+  # Only add the CMAKE_{C,CXX}_FLAGS if we are propagating host flags.  We
+  # always need to set the SHARED_FLAGS, though.
+  if(CUDA_PROPAGATE_HOST_FLAGS)
+    set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}")
+    _filter_blocklisted_host_flags(_cuda_C_FLAGS)
+    set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${_cuda_C_FLAGS} ${CUDA_HOST_SHARED_FLAGS})")
+  else()
+    set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${CUDA_HOST_SHARED_FLAGS})")
+  endif()
+
+  set(_cuda_nvcc_flags_config "# Build specific configuration flags")
+  # Loop over all the configuration types to generate appropriate flags for run_nvcc.cmake
+  foreach(config ${CUDA_configuration_types})
+    string(TOUPPER ${config} config_upper)
+    # CMAKE_FLAGS are strings and not lists.  By not putting quotes around CMAKE_FLAGS
+    # we convert the strings to lists (like we want).
+
+    if(CUDA_PROPAGATE_HOST_FLAGS)
+      # nvcc chokes on -g3 in versions previous to 3.0, so replace it with -g
+      set(_cuda_fix_g3 FALSE)
+
+      set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}")
+      _filter_blocklisted_host_flags(_cuda_C_FLAGS)
+      if(_cuda_fix_g3)
+        string(REPLACE "-g3" "-g" _cuda_C_FLAGS "${_cuda_C_FLAGS}")
+      endif()
+
+      string(APPEND _cuda_host_flags "\nset(CMAKE_HOST_FLAGS_${config_upper} ${_cuda_C_FLAGS})")
+    endif()
+
+    # Note that if we ever want CUDA_NVCC_FLAGS_ to be string (instead of a list
+    # like it is currently), we can remove the quotes around the
+    # ${CUDA_NVCC_FLAGS_${config_upper}} variable like the CMAKE_HOST_FLAGS_ variable.
+    string(APPEND _cuda_nvcc_flags_config "\nset(CUDA_NVCC_FLAGS_${config_upper} ${CUDA_NVCC_FLAGS_${config_upper}} ;; ${CUDA_WRAP_OPTION_NVCC_FLAGS_${config_upper}})")
+  endforeach()
+
+  # Process the C++ standard flag.  If the host sets -std=c++11 or -std=c++14, we need
+  # to hand the standard to nvcc through its own flag and remove it from the host
+  # flags.  This is because -Xcompile -std=c++ will choke nvcc (it uses the C
+  # preprocessor).  nvcc is always given c++14 here, which is also what a host
+  # -std=c++11 flag has historically resulted in.  The detection pattern and the
+  # removal pattern must cover the same flags, otherwise the host flag is left behind.
+  if( "${_cuda_host_flags}" MATCHES "-std=c\\+\\+(11|14)")
+    # Add the c++14 flag to nvcc if it isn't already present.  Note that we only look at
+    # the main flag instead of the configuration specific flags.
+    if( NOT "${CUDA_NVCC_FLAGS}" MATCHES "-std=c\\+\\+14" )
+      list(APPEND nvcc_flags --std c++14)
+    endif()
+    string(REGEX REPLACE "[-]+std=c\\+\\+(11|14)" "" _cuda_host_flags "${_cuda_host_flags}")
+  endif()
+
+  if(_cuda_build_shared_libs)
+    list(APPEND nvcc_flags "-D${cuda_target}_EXPORTS")
+  endif()
+
+  # Reset the output variable
+  set(_cuda_wrap_generated_files "")
+
+  # Iterate over the macro arguments and create custom
+  # commands for all the .cu files.
+  foreach(file ${_argn_list})
+    # Ignore any file marked as a HEADER_FILE_ONLY
+    get_source_file_property(_is_header ${file} HEADER_FILE_ONLY)
+    # Allow per source file overrides of the format.  Also allows compiling non-.cu files.
+    get_source_file_property(_cuda_source_format ${file} CUDA_SOURCE_PROPERTY_FORMAT)
+    if((${file} MATCHES "\\.cu$" OR _cuda_source_format) AND NOT _is_header)
+
+      if(NOT _cuda_source_format)
+        set(_cuda_source_format ${format})
+      endif()
+      # If file isn't a .cu file, we need to tell nvcc to treat it as such.
+      if(NOT file MATCHES "\\.cu$")
+        set(cuda_language_flag -x=cu)
+      else()
+        set(cuda_language_flag)
+      endif()
+
+      if( ${_cuda_source_format} MATCHES "OBJ")
+        set( cuda_compile_to_external_module OFF )
+      else()
+        set( cuda_compile_to_external_module ON )
+        if( ${_cuda_source_format} MATCHES "PTX" )
+          set( cuda_compile_to_external_module_type "ptx" )
+        elseif( ${_cuda_source_format} MATCHES "CUBIN")
+          set( cuda_compile_to_external_module_type "cubin" )
+        elseif( ${_cuda_source_format} MATCHES "FATBIN")
+          set( cuda_compile_to_external_module_type "fatbin" )
+        else()
+          message( FATAL_ERROR "Invalid format flag passed to CUDA_WRAP_SRCS or set with CUDA_SOURCE_PROPERTY_FORMAT file property for file '${file}': '${_cuda_source_format}'.  Use OBJ, PTX, CUBIN or FATBIN.")
+        endif()
+      endif()
+
+      if(cuda_compile_to_external_module)
+        # Don't use any of the host compilation flags for PTX targets.
+        set(CUDA_HOST_FLAGS)
+        set(CUDA_NVCC_FLAGS_CONFIG)
+      else()
+        set(CUDA_HOST_FLAGS ${_cuda_host_flags})
+        set(CUDA_NVCC_FLAGS_CONFIG ${_cuda_nvcc_flags_config})
+      endif()
+
+      # Determine output directory
+      cuda_compute_build_path("${file}" cuda_build_path)
+      set(cuda_compile_intermediate_directory "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${cuda_target}.dir/${cuda_build_path}")
+      if(CUDA_GENERATED_OUTPUT_DIR)
+        set(cuda_compile_output_dir "${CUDA_GENERATED_OUTPUT_DIR}")
+      else()
+        if ( cuda_compile_to_external_module )
+          set(cuda_compile_output_dir "${CMAKE_CURRENT_BINARY_DIR}")
+        else()
+          set(cuda_compile_output_dir "${cuda_compile_intermediate_directory}")
+        endif()
+      endif()
+
+      # Add a custom target to generate a c or ptx file. ######################
+
+      get_filename_component( basename ${file} NAME )
+      if( cuda_compile_to_external_module )
+        set(generated_file_path "${cuda_compile_output_dir}")
+        set(generated_file_basename "${cuda_target}_generated_${basename}.${cuda_compile_to_external_module_type}")
+        set(format_flag "-${cuda_compile_to_external_module_type}")
+        file(MAKE_DIRECTORY "${cuda_compile_output_dir}")
+      else()
+        set(generated_file_path "${cuda_compile_output_dir}/${CMAKE_CFG_INTDIR}")
+        set(generated_file_basename "${cuda_target}_generated_${basename}${generated_extension}")
+        if(CUDA_SEPARABLE_COMPILATION)
+          set(format_flag "-dc")
+        else()
+          set(format_flag "-c")
+        endif()
+      endif()
+
+      # Set all of our file names.  Make sure that whatever filenames that have
+      # generated_file_path in them get passed in through as a command line
+      # argument, so that the ${CMAKE_CFG_INTDIR} gets expanded at run time
+      # instead of configure time.
+      set(generated_file "${generated_file_path}/${generated_file_basename}")
+      set(cmake_dependency_file "${cuda_compile_intermediate_directory}/${generated_file_basename}.depend")
+      set(NVCC_generated_dependency_file "${cuda_compile_intermediate_directory}/${generated_file_basename}.NVCC-depend")
+      set(generated_cubin_file "${generated_file_path}/${generated_file_basename}.cubin.txt")
+      set(custom_target_script_pregen "${cuda_compile_intermediate_directory}/${generated_file_basename}.cmake.pre-gen")
+      # The script name carries a ".<CONFIG>" infix whenever a configuration is active,
+      # so that file(GENERATE) writes one distinct script per configuration.
+      set(custom_target_script "${cuda_compile_intermediate_directory}/${generated_file_basename}$<$<BOOL:$<CONFIG>>:.$<CONFIG>>.cmake")
+
+      # Setup properties for obj files:
+      if( NOT cuda_compile_to_external_module )
+        set_source_files_properties("${generated_file}"
+          PROPERTIES
+          EXTERNAL_OBJECT true # This is an object file not to be compiled, but only be linked.
+          )
+      endif()
+
+      # Don't add CMAKE_CURRENT_SOURCE_DIR if the path is already an absolute path.
+      get_filename_component(file_path "${file}" PATH)
+      if(IS_ABSOLUTE "${file_path}")
+        set(source_file "${file}")
+      else()
+        set(source_file "${CMAKE_CURRENT_SOURCE_DIR}/${file}")
+      endif()
+
+      if( NOT cuda_compile_to_external_module AND CUDA_SEPARABLE_COMPILATION)
+        list(APPEND ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS "${generated_file}")
+      endif()
+
+      # Bring in the dependencies.  Creates a variable CUDA_NVCC_DEPEND #######
+      cuda_include_nvcc_dependencies(${cmake_dependency_file})
+
+      # Convenience string for output #########################################
+      if(CUDA_BUILD_EMULATION)
+        set(cuda_build_type "Emulation")
+      else()
+        set(cuda_build_type "Device")
+      endif()
+
+      # Build the NVCC made dependency file ###################################
+      set(build_cubin OFF)
+      if ( NOT CUDA_BUILD_EMULATION AND CUDA_BUILD_CUBIN )
+         if ( NOT cuda_compile_to_external_module )
+           set ( build_cubin ON )
+         endif()
+      endif()
+
+      # Configure the build script
+      configure_file("${CUDA_run_nvcc}" "${custom_target_script_pregen}" @ONLY)
+      file(GENERATE
+        OUTPUT "${custom_target_script}"
+        INPUT "${custom_target_script_pregen}"
+        )
+
+      # So if a user specifies the same cuda file as input more than once, you
+      # can have bad things happen with dependencies.  Here we check an option
+      # to see if this is the behavior they want.
+      if(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE)
+        set(main_dep MAIN_DEPENDENCY ${source_file})
+      else()
+        set(main_dep DEPENDS ${source_file})
+      endif()
+
+      if(CUDA_VERBOSE_BUILD)
+        set(verbose_output ON)
+      elseif(CMAKE_GENERATOR MATCHES "Makefiles")
+        set(verbose_output "$(VERBOSE)")
+      # This condition lets us also turn on verbose output when someone
+      # specifies CMAKE_VERBOSE_MAKEFILE, even if the generator isn't
+      # the Makefiles generator (this is important for us, Ninja users.)
+      elseif(CMAKE_VERBOSE_MAKEFILE)
+        set(verbose_output ON)
+      else()
+        set(verbose_output OFF)
+      endif()
+
+      # Create up the comment string
+      file(RELATIVE_PATH generated_file_relative_path "${CMAKE_BINARY_DIR}" "${generated_file}")
+      if(cuda_compile_to_external_module)
+        set(cuda_build_comment_string "Building NVCC ${cuda_compile_to_external_module_type} file ${generated_file_relative_path}")
+      else()
+        set(cuda_build_comment_string "Building NVCC (${cuda_build_type}) object ${generated_file_relative_path}")
+      endif()
+
+      set(_verbatim VERBATIM)
+      if(ccbin_flags MATCHES "\\$\\(VCInstallDir\\)")
+        set(_verbatim "")
+      endif()
+
+      # Build the generated file and dependency file ##########################
+      add_custom_command(
+        OUTPUT ${generated_file}
+        # These output files depend on the source_file and the contents of cmake_dependency_file
+        ${main_dep}
+        DEPENDS ${CUDA_NVCC_DEPEND}
+        DEPENDS ${custom_target_script}
+        # Make sure the output directory exists before trying to write to it.
+        COMMAND ${CMAKE_COMMAND} -E make_directory "${generated_file_path}"
+        COMMAND ${CMAKE_COMMAND} ARGS
+          -D verbose:BOOL=${verbose_output}
+          ${ccbin_flags}
+          -D build_configuration:STRING=${CUDA_build_configuration}
+          -D "generated_file:STRING=${generated_file}"
+          -D "generated_cubin_file:STRING=${generated_cubin_file}"
+          -P "${custom_target_script}"
+        WORKING_DIRECTORY "${cuda_compile_intermediate_directory}"
+        COMMENT "${cuda_build_comment_string}"
+        ${_verbatim}
+        )
+
+      # Make sure the build system knows the file is generated.
+      set_source_files_properties(${generated_file} PROPERTIES GENERATED TRUE)
+
+      list(APPEND _cuda_wrap_generated_files ${generated_file})
+
+      # Add the other files that we want cmake to clean on a cleanup ##########
+      list(APPEND CUDA_ADDITIONAL_CLEAN_FILES "${cmake_dependency_file}")
+      list(REMOVE_DUPLICATES CUDA_ADDITIONAL_CLEAN_FILES)
+      set(CUDA_ADDITIONAL_CLEAN_FILES ${CUDA_ADDITIONAL_CLEAN_FILES} CACHE INTERNAL "List of intermediate files that are part of the cuda dependency scanning.")
+
+    endif()
+  endforeach()
+
+  # Set the return parameter
+  set(${generated_files} ${_cuda_wrap_generated_files})
+endmacro()
+
+function(_cuda_get_important_host_flags important_flags flag_string)
+  # Extract the host compiler flags that must also reach the device link step and
+  # append them to the caller's list.
+  #   important_flags - name of the caller's list variable (extended, not replaced)
+  #   flag_string     - host flag string to scan
+  # Visual Studio generators keep the runtime selection flags (/MD, /MT and their
+  # debug variants); every other generator keeps -fPIC.
+  if(CMAKE_GENERATOR MATCHES "Visual Studio")
+    set(_wanted_regex "/M[DT][d]?")
+  else()
+    set(_wanted_regex "-fPIC")
+  endif()
+  string(REGEX MATCHALL "${_wanted_regex}" _picked "${flag_string}")
+  set(${important_flags} ${${important_flags}} ${_picked} PARENT_SCOPE)
+endfunction()
+
+###############################################################################
+###############################################################################
+# Separable Compilation Link
+###############################################################################
+###############################################################################
+
+# Compute the filename to be used by CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS
+function(CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME output_file_var cuda_target object_files)
+  # Compute the path of the intermediate device-link object for cuda_target.
+  #   output_file_var - name of the caller's variable that receives the path
+  #   cuda_target     - target name, used for the directory and the file name
+  #   object_files    - separable-compilation objects; when this is empty the
+  #                     result is the empty string
+  set(_link_object)
+  if(object_files)
+    set(_object_ext ${CMAKE_${CUDA_C_OR_CXX}_OUTPUT_EXTENSION})
+    set(_link_object "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${cuda_target}.dir/${CMAKE_CFG_INTDIR}/${cuda_target}_intermediate_link${_object_ext}")
+  endif()
+
+  set(${output_file_var} "${_link_object}" PARENT_SCOPE)
+endfunction()
+
+# Setup the build rule for the separable compilation intermediate link file.
+function(CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS output_file cuda_target options object_files)
+  # Create the build rule that device-links (nvcc -dlink) the separable-compilation
+  # objects of a target into one intermediate object.
+  #   output_file  - path of the intermediate link object to produce
+  #   cuda_target  - target the objects belong to
+  #   options      - extra nvcc options (configuration keywords are parsed out)
+  #   object_files - objects to link; nothing is set up when this is empty
+  if (object_files)
+
+    set_source_files_properties("${output_file}"
+      PROPERTIES
+      EXTERNAL_OBJECT TRUE # This is an object file not to be compiled, but only
+                           # be linked.
+      GENERATED TRUE       # This file is generated during the build
+      )
+
+    # For now we are ignoring all the configuration specific flags.
+    set(nvcc_flags)
+    CUDA_PARSE_NVCC_OPTIONS(nvcc_flags ${options})
+    if(CUDA_64_BIT_DEVICE_CODE)
+      list(APPEND nvcc_flags -m64)
+    else()
+      list(APPEND nvcc_flags -m32)
+    endif()
+    # If -ccbin, --compiler-bindir has been specified, don't do anything.  Otherwise add it here.
+    list( FIND nvcc_flags "-ccbin" ccbin_found0 )
+    list( FIND nvcc_flags "--compiler-bindir" ccbin_found1 )
+    if( ccbin_found0 LESS 0 AND ccbin_found1 LESS 0 AND CUDA_HOST_COMPILER )
+      # Match VERBATIM check below.
+      if(CUDA_HOST_COMPILER MATCHES "\\$\\(VCInstallDir\\)")
+        list(APPEND nvcc_flags -ccbin "\"${CUDA_HOST_COMPILER}\"")
+      else()
+        list(APPEND nvcc_flags -ccbin "${CUDA_HOST_COMPILER}")
+      endif()
+    endif()
+
+    # Create a list of flags specified by CUDA_NVCC_FLAGS_${CONFIG} and CMAKE_${CUDA_C_OR_CXX}_FLAGS*
+    set(config_specific_flags)
+    set(flags)
+    foreach(config ${CUDA_configuration_types})
+      string(TOUPPER ${config} config_upper)
+      # Add config specific flags.  Each flag is wrapped in a $<CONFIG:...> generator
+      # expression so that it only appears on the command line of its own configuration.
+      foreach(f ${CUDA_NVCC_FLAGS_${config_upper}})
+        list(APPEND config_specific_flags $<$<CONFIG:${config}>:${f}>)
+      endforeach()
+      set(important_host_flags)
+      _cuda_get_important_host_flags(important_host_flags "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}")
+      foreach(f ${important_host_flags})
+        # -Xcompiler and the flag it introduces are guarded by the same condition so
+        # that they appear (or vanish) together.
+        list(APPEND flags $<$<CONFIG:${config}>:-Xcompiler> $<$<CONFIG:${config}>:${f}>)
+      endforeach()
+    endforeach()
+    # Add CMAKE_${CUDA_C_OR_CXX}_FLAGS
+    set(important_host_flags)
+    _cuda_get_important_host_flags(important_host_flags "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}")
+    foreach(f ${important_host_flags})
+      list(APPEND flags -Xcompiler ${f})
+    endforeach()
+
+    # Add our general CUDA_NVCC_FLAGS with the configuration specific flags
+    set(nvcc_flags ${CUDA_NVCC_FLAGS} ${config_specific_flags} ${nvcc_flags})
+
+    file(RELATIVE_PATH output_file_relative_path "${CMAKE_BINARY_DIR}" "${output_file}")
+
+    # Some generators don't handle the multiple levels of custom command
+    # dependencies correctly (obj1 depends on file1, obj2 depends on obj1), so
+    # we work around that issue by compiling the intermediate link object as a
+    # pre-link custom command in that situation.
+    set(do_obj_build_rule TRUE)
+    if (MSVC_VERSION GREATER 1599 AND MSVC_VERSION LESS 1800)
+      # VS 2010 and 2012 have this problem.
+      set(do_obj_build_rule FALSE)
+    endif()
+
+    set(_verbatim VERBATIM)
+    if(nvcc_flags MATCHES "\\$\\(VCInstallDir\\)")
+      set(_verbatim "")
+    endif()
+
+    if (do_obj_build_rule)
+      add_custom_command(
+        OUTPUT ${output_file}
+        DEPENDS ${object_files}
+        COMMAND ${CUDA_NVCC_EXECUTABLE} ${nvcc_flags} -dlink ${object_files} -o ${output_file}
+        ${flags}
+        COMMENT "Building NVCC intermediate link file ${output_file_relative_path}"
+        COMMAND_EXPAND_LISTS
+        ${_verbatim}
+        )
+    else()
+      get_filename_component(output_file_dir "${output_file}" DIRECTORY)
+      add_custom_command(
+        TARGET ${cuda_target}
+        PRE_LINK
+        COMMAND ${CMAKE_COMMAND} -E echo "Building NVCC intermediate link file ${output_file_relative_path}"
+        COMMAND ${CMAKE_COMMAND} -E make_directory "${output_file_dir}"
+        COMMAND ${CUDA_NVCC_EXECUTABLE} ${nvcc_flags} ${flags} -dlink ${object_files} -o "${output_file}"
+        COMMAND_EXPAND_LISTS
+        ${_verbatim}
+        )
+    endif()
+  endif()
+endfunction()
+
+###############################################################################
+###############################################################################
+# ADD LIBRARY
+###############################################################################
+###############################################################################
+macro(CUDA_ADD_LIBRARY cuda_target)
+  # Build library target ${cuda_target} from the sources and options in ARGN.
+  CUDA_ADD_CUDA_INCLUDE_ONCE()
+
+  # Separate the sources from the options
+  CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN})
+  CUDA_BUILD_SHARED_LIBRARY(_cuda_shared_flag ${ARGN})
+  # Create custom commands and targets for each file.
+  CUDA_WRAP_SRCS( ${cuda_target} OBJ _generated_files ${_sources}
+    ${_cmake_options} ${_cuda_shared_flag}
+    OPTIONS ${_options} )
+
+  # Compute the file name of the intermediate link file used for separable
+  # compilation.
+  CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME(link_file ${cuda_target} "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+
+  # Add the library: generated files, the original sources and the link file.
+  add_library(${cuda_target} ${_cmake_options}
+    ${_generated_files}
+    ${_sources}
+    ${link_file}
+    )
+
+  # Add a link phase for the separable compilation if it has been enabled.  If
+  # it has been enabled then the ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS
+  # variable will have been defined.
+  CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS("${link_file}" ${cuda_target} "${_options}" "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+  # Always link against the libraries listed in CUDA_LIBRARIES.
+  target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD}
+    ${CUDA_LIBRARIES}
+    )
+  # Separable compilation additionally links CUDA_cudadevrt_LIBRARY.
+  if(CUDA_SEPARABLE_COMPILATION)
+    target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD}
+      ${CUDA_cudadevrt_LIBRARY}
+      )
+  endif()
+
+  # We need to set the linker language based on what the expected generated file
+  # would be. CUDA_C_OR_CXX is computed based on CUDA_HOST_COMPILATION_CPP.
+  set_target_properties(${cuda_target}
+    PROPERTIES
+    LINKER_LANGUAGE ${CUDA_C_OR_CXX}
+    )
+
+endmacro()
+
+
+###############################################################################
+###############################################################################
+# ADD EXECUTABLE
+###############################################################################
+###############################################################################
+macro(CUDA_ADD_EXECUTABLE cuda_target)
+  # Build executable target ${cuda_target} from the sources and options in ARGN.
+  CUDA_ADD_CUDA_INCLUDE_ONCE()
+
+  # Separate the sources from the options
+  CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN})
+  # Create custom commands and targets for each file.
+  CUDA_WRAP_SRCS( ${cuda_target} OBJ _generated_files ${_sources} OPTIONS ${_options} )
+
+  # Compute the file name of the intermediate link file used for separable
+  # compilation.
+  CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME(link_file ${cuda_target} "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+
+  # Add the executable: generated files, the original sources and the link file.
+  add_executable(${cuda_target} ${_cmake_options}
+    ${_generated_files}
+    ${_sources}
+    ${link_file}
+    )
+
+  # Add a link phase for the separable compilation if it has been enabled.  If
+  # it has been enabled then the ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS
+  # variable will have been defined.
+  CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS("${link_file}" ${cuda_target} "${_options}" "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+  # Always link against the libraries listed in CUDA_LIBRARIES.
+  target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD}
+    ${CUDA_LIBRARIES}
+    )
+  # NOTE: unlike CUDA_ADD_LIBRARY, CUDA_cudadevrt_LIBRARY is not linked here.
+  # We need to set the linker language based on what the expected generated file
+  # would be. CUDA_C_OR_CXX is computed based on CUDA_HOST_COMPILATION_CPP.
+  set_target_properties(${cuda_target}
+    PROPERTIES
+    LINKER_LANGUAGE ${CUDA_C_OR_CXX}
+    )
+
+endmacro()
+
+
+###############################################################################
+# (Internal) helper for manually added cuda source files with specific targets
+#   cuda_target: name prefix; "_<counter>" is appended to keep it unique.
+#   format: forwarded to CUDA_WRAP_SRCS; generated_files: output variable name.
+###############################################################################
+macro(cuda_compile_base cuda_target format generated_files)
+  # Update a counter in this directory, to keep phony target names unique.
+  set(_cuda_target "${cuda_target}")
+  get_property(_counter DIRECTORY PROPERTY _cuda_internal_phony_counter)
+  if(_counter)
+    math(EXPR _counter "${_counter} + 1")
+  else()
+    set(_counter 1)
+  endif()
+  string(APPEND _cuda_target "_${_counter}")
+  set_property(DIRECTORY PROPERTY _cuda_internal_phony_counter ${_counter})
+
+  # Separate the sources from the options
+  CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN})
+
+  # Create custom commands and targets for each file.
+  CUDA_WRAP_SRCS( ${_cuda_target} ${format} _generated_files ${_sources}
+                  ${_cmake_options} OPTIONS ${_options} PHONY)
+  # Hand the generated file list back through the variable named by generated_files.
+  set( ${generated_files} ${_generated_files})
+
+endmacro()
+
+###############################################################################
+# CUDA COMPILE
+# Thin wrapper over cuda_compile_base: name prefix cuda_compile, format OBJ.
+# The generated file list is stored in the variable named by generated_files.
+###############################################################################
+macro(CUDA_COMPILE generated_files)
+  cuda_compile_base(cuda_compile OBJ ${generated_files} ${ARGN})
+endmacro()
+
+###############################################################################
+# CUDA COMPILE PTX
+# Thin wrapper over cuda_compile_base: name prefix cuda_compile_ptx, format PTX.
+# The generated file list is stored in the variable named by generated_files.
+###############################################################################
+macro(CUDA_COMPILE_PTX generated_files)
+  cuda_compile_base(cuda_compile_ptx PTX ${generated_files} ${ARGN})
+endmacro()
+
+###############################################################################
+# CUDA COMPILE FATBIN
+# Thin wrapper over cuda_compile_base: prefix cuda_compile_fatbin, format FATBIN.
+# The generated file list is stored in the variable named by generated_files.
+###############################################################################
+macro(CUDA_COMPILE_FATBIN generated_files)
+  cuda_compile_base(cuda_compile_fatbin FATBIN ${generated_files} ${ARGN})
+endmacro()
+
+###############################################################################
+# CUDA COMPILE CUBIN
+# Thin wrapper over cuda_compile_base: prefix cuda_compile_cubin, format CUBIN.
+# The generated file list is stored in the variable named by generated_files.
+###############################################################################
+macro(CUDA_COMPILE_CUBIN generated_files)
+  cuda_compile_base(cuda_compile_cubin CUBIN ${generated_files} ${ARGN})
+endmacro()
+
+
+###############################################################################
+# CUDA ADD CUFFT TO TARGET
+# Links ${target} against CUDA_cufftemu_LIBRARY when CUDA_BUILD_EMULATION is
+# set, otherwise against CUDA_cufft_LIBRARY.
+###############################################################################
+macro(CUDA_ADD_CUFFT_TO_TARGET target)
+  if (CUDA_BUILD_EMULATION)
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cufftemu_LIBRARY})
+  else()
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cufft_LIBRARY})
+  endif()
+endmacro()
+
+###############################################################################
+# CUDA ADD CUBLAS TO TARGET
+# Links ${target} against CUDA_cublasemu_LIBRARY when CUDA_BUILD_EMULATION is
+# set, otherwise against the cublas, cublas_device and cublasLt libraries.
+###############################################################################
+macro(CUDA_ADD_CUBLAS_TO_TARGET target)
+  if (CUDA_BUILD_EMULATION)
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublasemu_LIBRARY})
+  else()
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
+  endif()
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA BUILD CLEAN TARGET
+###############################################################################
+###############################################################################
+macro(CUDA_BUILD_CLEAN_TARGET)
+  # Call this after you add all your CUDA targets, and you will get a
+  # convenience target.  You should also make clean after running this target
+  # to get the build system to generate all the code again.
+  # Target clean_cuda_depends (upper-cased for Visual Studio) removes CUDA_ADDITIONAL_CLEAN_FILES.
+  set(cuda_clean_target_name clean_cuda_depends)
+  if (CMAKE_GENERATOR MATCHES "Visual Studio")
+    string(TOUPPER ${cuda_clean_target_name} cuda_clean_target_name)
+  endif()
+  add_custom_target(${cuda_clean_target_name}
+    COMMAND ${CMAKE_COMMAND} -E remove ${CUDA_ADDITIONAL_CLEAN_FILES})
+
+  # Clear out the variable, so the next time we configure it will be empty.
+  # This is useful so that the files won't persist in the list after targets
+  # have been removed.
+  set(CUDA_ADDITIONAL_CLEAN_FILES "" CACHE INTERNAL "List of intermediate files that are part of the cuda dependency scanning.")
+endmacro()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..580f24a400d8c5662ec572c4631db9e3e47645d9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake
@@ -0,0 +1,106 @@
+#  James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#  Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#  Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#  Copyright (c) 2007-2009
+#  Scientific Computing and Imaging Institute, University of Utah
+#
+#  This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#  for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+
+#######################################################################
+# This converts a file written in makefile syntax into one that can be included
+# by CMake.
+
+# Input variables
+#
+# verbose:BOOL=<>          OFF: Be as quiet as possible (default)
+#                          ON : Extra output
+#
+# input_file:FILEPATH=<>   Path to dependency file in makefile format
+#
+# output_file:FILEPATH=<>  Path to file with dependencies in CMake readable variable
+#
+
+file(READ ${input_file} depend_text)
+
+# Initialize unconditionally: with empty input the branch below is skipped.
+set(dependency_list "")
+set(cuda_nvcc_depend "")
+if (NOT "${depend_text}" STREQUAL "")
+  # message("FOUND DEPENDS")
+  # Turn escaped spaces ("\ ") into plain spaces.
+  string(REPLACE "\\ " " " depend_text ${depend_text})
+
+  # This works for the nvcc -M generated dependency files.
+  string(REGEX REPLACE "^.* : " "" depend_text ${depend_text})
+  string(REGEX REPLACE "[ \\\\]*\n" ";" depend_text ${depend_text})
+
+  foreach(file ${depend_text})
+
+    string(REGEX REPLACE "^ +" "" file ${file})
+
+    # OK, now if we had a UNC path, nvcc has a tendency to only output the first '/'
+    # instead of '//'.  Here we will test to see if the file exists, if it doesn't then
+    # try to prepend another '/' to the path and test again.  If it still fails remove the
+    # path.
+
+    if(NOT EXISTS "${file}")
+      if (EXISTS "/${file}")
+        set(file "/${file}")
+      else()
+        if(verbose)
+          message(WARNING " Removing non-existent dependency file: ${file}")
+        endif()
+        set(file "")
+      endif()
+    endif()
+
+    # Make sure we check to see if we have a file, before asking if it is not a directory.
+    # if(NOT IS_DIRECTORY "") will return TRUE.
+    if(file AND NOT IS_DIRECTORY "${file}")
+      # If softlinks start to matter, we should change this to REALPATH.  For now we need
+      # to flatten paths, because nvcc can generate stuff like /bin/../include instead of
+      # just /include.
+      get_filename_component(file_absolute "${file}" ABSOLUTE)
+      list(APPEND dependency_list "${file_absolute}")
+    endif()
+
+  endforeach()
+
+else()
+  # message("FOUND NO DEPENDS")
+endif()
+
+# Remove the duplicate entries and sort them.
+list(REMOVE_DUPLICATES dependency_list)
+list(SORT dependency_list)
+
+foreach(file ${dependency_list})
+  string(APPEND cuda_nvcc_depend " \"${file}\"\n")
+endforeach()
+# Emit a CMake script that sets CUDA_NVCC_DEPEND to the sorted, de-duplicated list.
+file(WRITE ${output_file} "# Generated by: make2cmake.cmake\nSET(CUDA_NVCC_DEPEND\n ${cuda_nvcc_depend})\n\n")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..25ceb49f3dd8e684e35cac49834c4db0aa5c338a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake
@@ -0,0 +1,109 @@
+#  James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#  Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#  Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#  Copyright (c) 2007-2009
+#  Scientific Computing and Imaging Institute, University of Utah
+#
+#  This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#  for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+
+#######################################################################
+# Parses a .cubin file produced by nvcc and reports statistics about the file.
+
+
+file(READ ${input_file} file_text)
+
+if (NOT "${file_text}" STREQUAL "")
+  # Intended to escape literal ";" before splitting into a list at each "\ncode".
+  string(REPLACE ";" "\\;" file_text ${file_text})
+  string(REPLACE "\ncode" ";code" file_text ${file_text})
+  # NOTE: len is computed but not read anywhere in this script.
+  list(LENGTH file_text len)
+
+  foreach(line ${file_text})
+
+    # Only look at "code { }" blocks.
+    if(line MATCHES "^code")
+
+      # Break into individual lines.
+      string(REGEX REPLACE "\n" ";" line ${line})
+
+      foreach(entry ${line})
+
+        # Extract kernel names.
+        if (${entry} MATCHES "[^g]name = ([^ ]+)")
+          set(entry "${CMAKE_MATCH_1}")
+
+          # The "_"-prefix check below is commented out, so skip is never set to TRUE.
+          set(skip FALSE)
+          # if (${entry} MATCHES "^_")
+            # Skip the rest of this block.
+            # message("Skipping ${entry}")
+            # set(skip TRUE)
+          # else ()
+            message("Kernel:    ${entry}")
+          # endif ()
+
+        endif()
+
+        # Skip the rest of the block if necessary
+        if(NOT skip)
+
+          # Registers
+          if (${entry} MATCHES "reg([ ]+)=([ ]+)([^ ]+)")
+            set(entry "${CMAKE_MATCH_3}")
+            message("Registers: ${entry}")
+          endif()
+
+          # Local memory
+          if (${entry} MATCHES "lmem([ ]+)=([ ]+)([^ ]+)")
+            set(entry "${CMAKE_MATCH_3}")
+            message("Local:     ${entry}")
+          endif()
+
+          # Shared memory
+          if (${entry} MATCHES "smem([ ]+)=([ ]+)([^ ]+)")
+            set(entry "${CMAKE_MATCH_3}")
+            message("Shared:    ${entry}")
+          endif()
+          # A line starting with "}" ends the block: print an empty separator line.
+          if (${entry} MATCHES "^}")
+            message("")
+          endif()
+
+        endif()
+
+
+      endforeach()
+
+    endif()
+
+  endforeach()
+
+else()
+  # message("FOUND NO DEPENDS")
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..59c5c11a1091f34df89b681a926db602a1c75caa
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake
@@ -0,0 +1,303 @@
+#  James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#
+#  Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#  This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#  for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+
+##########################################################################
+# This file runs the nvcc commands to produce the desired output file along with
+# the dependency file needed by CMake to compute dependencies.  In addition the
+# file checks the output of each command and if the command fails it deletes the
+# output files.
+
+# Input variables
+#
+# verbose:BOOL=<>          OFF: Be as quiet as possible (default)
+#                          ON : Describe each step
+#
+# build_configuration:STRING=<> Typically one of Debug, MinSizeRel, Release, or
+#                               RelWithDebInfo, but it should match one of the
+#                               entries in CUDA_HOST_FLAGS. This is the build
+#                               configuration used when compiling the code.  If
+#                               blank or unspecified Debug is assumed as this is
+#                               what CMake does.
+#
+# generated_file:STRING=<> File to generate.  This argument must be passed in.
+#
+# generated_cubin_file:STRING=<> File to generate.  This argument must be passed
+#                                                   in if build_cubin is true.
+
+cmake_policy(PUSH)
+cmake_policy(SET CMP0007 NEW)
+cmake_policy(SET CMP0010 NEW)
+if(NOT generated_file)
+  message(FATAL_ERROR "You must specify generated_file on the command line")
+endif()
+# NOTE(review): @NAME@ tokens look like configure_file() placeholders -- confirm.
+# Set these up as variables to make reading the generated file easier
+set(CMAKE_COMMAND "@CMAKE_COMMAND@") # path
+set(source_file "@source_file@") # path
+set(NVCC_generated_dependency_file "@NVCC_generated_dependency_file@") # path
+set(cmake_dependency_file "@cmake_dependency_file@") # path
+set(CUDA_make2cmake "@CUDA_make2cmake@") # path
+set(CUDA_parse_cubin "@CUDA_parse_cubin@") # path
+set(build_cubin @build_cubin@) # bool
+set(CUDA_HOST_COMPILER "@CUDA_HOST_COMPILER@") # path
+# We won't actually use these variables for now, but we need to set this, in
+# order to force this file to be run again if it changes.
+set(generated_file_path "@generated_file_path@") # path
+set(generated_file_internal "@generated_file@") # path
+set(generated_cubin_file_internal "@generated_cubin_file@") # path
+
+set(CUDA_NVCC_EXECUTABLE "@CUDA_NVCC_EXECUTABLE@") # path
+set(CUDA_NVCC_FLAGS @CUDA_NVCC_FLAGS@ ;; @CUDA_WRAP_OPTION_NVCC_FLAGS@) # list
+@CUDA_NVCC_FLAGS_CONFIG@
+set(nvcc_flags @nvcc_flags@) # list
+set(CUDA_NVCC_INCLUDE_DIRS [==[@CUDA_NVCC_INCLUDE_DIRS@]==]) # list (needs to be in lua quotes to address backslashes)
+string(REPLACE "\\" "/" CUDA_NVCC_INCLUDE_DIRS "${CUDA_NVCC_INCLUDE_DIRS}")
+set(CUDA_NVCC_COMPILE_DEFINITIONS [==[@CUDA_NVCC_COMPILE_DEFINITIONS@]==]) # list (needs to be in lua quotes see #16510 ).
+set(format_flag "@format_flag@") # string
+set(cuda_language_flag @cuda_language_flag@) # list
+
+# Clean up list of include directories and add -I flags
+list(REMOVE_DUPLICATES CUDA_NVCC_INCLUDE_DIRS)
+set(CUDA_NVCC_INCLUDE_ARGS)
+foreach(dir ${CUDA_NVCC_INCLUDE_DIRS})
+  # Prefix each include directory with -I (no extra quoting is applied here).
+  list(APPEND CUDA_NVCC_INCLUDE_ARGS "-I${dir}")
+endforeach()
+
+# Clean up list of compile definitions, add -D flags, and append to nvcc_flags
+list(REMOVE_DUPLICATES CUDA_NVCC_COMPILE_DEFINITIONS)
+foreach(def ${CUDA_NVCC_COMPILE_DEFINITIONS})
+  list(APPEND nvcc_flags "-D${def}")
+endforeach()
+
+if(build_cubin AND NOT generated_cubin_file)
+  message(FATAL_ERROR "You must specify generated_cubin_file on the command line")
+endif()
+
+# This is the list of host compilation flags.  Either C or CXX should already have
+# been chosen by FindCUDA.cmake.
+@CUDA_HOST_FLAGS@
+
+# Take the compiler flags and package them up to be sent to the compiler via -Xcompiler
+set(nvcc_host_compiler_flags "")
+# If we weren't given a build_configuration, use Debug.
+if(NOT build_configuration)
+  set(build_configuration Debug)
+endif()
+string(TOUPPER "${build_configuration}" build_configuration)
+#message("CUDA_NVCC_HOST_COMPILER_FLAGS = ${CUDA_NVCC_HOST_COMPILER_FLAGS}")
+foreach(flag ${CMAKE_HOST_FLAGS} ${CMAKE_HOST_FLAGS_${build_configuration}})
+  # Extra quotes are added around each flag to help nvcc parse out flags with spaces.
+  string(APPEND nvcc_host_compiler_flags ",\"${flag}\"")
+endforeach()
+if (nvcc_host_compiler_flags)
+  set(nvcc_host_compiler_flags "-Xcompiler" ${nvcc_host_compiler_flags})
+endif()
+#message("nvcc_host_compiler_flags = \"${nvcc_host_compiler_flags}\"")
+# Add the build specific configuration flags
+list(APPEND CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS_${build_configuration}})
+# CCBIN is only set here when CUDA_NVCC_FLAGS does not already carry a -ccbin.
+# Any -ccbin existing in CUDA_NVCC_FLAGS gets highest priority
+list( FIND CUDA_NVCC_FLAGS "-ccbin" ccbin_found0 )
+list( FIND CUDA_NVCC_FLAGS "--compiler-bindir" ccbin_found1 )
+if( ccbin_found0 LESS 0 AND ccbin_found1 LESS 0 AND CUDA_HOST_COMPILER )
+  if (CUDA_HOST_COMPILER STREQUAL "@_CUDA_MSVC_HOST_COMPILER@" AND DEFINED CCBIN)
+    set(CCBIN -ccbin "${CCBIN}")
+  else()
+    set(CCBIN -ccbin "${CUDA_HOST_COMPILER}")
+  endif()
+endif()
+
+# cuda_execute_process - Executes a command with optional command echo and status message.
+#
+#   status  - Status message to print if verbose is true
+#   command - Must be the literal keyword COMMAND; anything else is a FATAL_ERROR
+#   ARGN    - Remaining arguments are the command with arguments
+#   verbose - Variable from the calling scope; enables the status and command echo
+#   CUDA_result - (output) RESULT_VARIABLE of the executed command
+#
+# Make this a macro instead of a function, so that things like RESULT_VARIABLE
+# and other return variables are present after executing the process.
+macro(cuda_execute_process status command)
+  set(_command ${command})
+  if(NOT "x${_command}" STREQUAL "xCOMMAND")
+    message(FATAL_ERROR "Malformed call to cuda_execute_process.  Missing COMMAND as second argument. (command = ${command})")
+  endif()
+  if(verbose)
+    execute_process(COMMAND "${CMAKE_COMMAND}" -E echo -- ${status})
+    # Now we need to build up our command string.  We are accounting for quotes
+    # and spaces, anything else is left up to the user to fix if they want to
+    # copy and paste a runnable command line.
+    set(cuda_execute_process_string)
+    foreach(arg ${ARGN})
+      # If there are quotes, escape them, so they come through.
+      string(REPLACE "\"" "\\\"" arg ${arg})
+      # Args with spaces need quotes around them to get them to be parsed as a single argument.
+      if(arg MATCHES " ")
+        list(APPEND cuda_execute_process_string "\"${arg}\"")
+      else()
+        list(APPEND cuda_execute_process_string ${arg})
+      endif()
+    endforeach()
+    # Echo the command
+    execute_process(COMMAND ${CMAKE_COMMAND} -E echo ${cuda_execute_process_string})
+  endif()
+  # Run the command; its result is left in CUDA_result for the caller to check.
+  execute_process(COMMAND ${ARGN} RESULT_VARIABLE CUDA_result )
+endmacro()
+
+# Delete the target file
+cuda_execute_process(
+  "Removing ${generated_file}"
+  COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}"
+  )
+
+# Flags for the dependency pass.  Historically -G was stripped here because
+# -G -M did not work for CUDA 2.3 and below; the flags are now copied unchanged.
+set(depends_CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
+set(CUDA_VERSION @CUDA_VERSION@)
+
+# nvcc doesn't define __CUDACC__ for some reason when generating dependency files.  This
+# can cause incorrect dependencies when #including files based on this macro which is
+# defined in the generating passes of nvcc invocation.  We will go ahead and manually
+# define this for now until a future version fixes this bug.
+set(CUDACC_DEFINE -D__CUDACC__)
+
+# Generate the dependency file
+cuda_execute_process(
+  "Generating dependency file: ${NVCC_generated_dependency_file}"
+  COMMAND "${CUDA_NVCC_EXECUTABLE}"
+  -M
+  ${CUDACC_DEFINE}
+  "${source_file}"
+  -o "${NVCC_generated_dependency_file}"
+  ${CCBIN}
+  ${nvcc_flags}
+  ${nvcc_host_compiler_flags}
+  ${depends_CUDA_NVCC_FLAGS}
+  -DNVCC
+  ${CUDA_NVCC_INCLUDE_ARGS}
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Generate the cmake readable dependency file to a temp file.  Don't put the
+# quotes just around the filenames for the input_file and output_file variables.
+# CMake will pass the quotes through and not be able to find the file.
+cuda_execute_process(
+  "Generating temporary cmake readable file: ${cmake_dependency_file}.tmp"
+  COMMAND "${CMAKE_COMMAND}"
+  -D "input_file:FILEPATH=${NVCC_generated_dependency_file}"
+  -D "output_file:FILEPATH=${cmake_dependency_file}.tmp"
+  -D "verbose=${verbose}"
+  -P "${CUDA_make2cmake}"
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Copy the file if it is different
+cuda_execute_process(
+  "Copy if different ${cmake_dependency_file}.tmp to ${cmake_dependency_file}"
+  COMMAND "${CMAKE_COMMAND}" -E copy_if_different "${cmake_dependency_file}.tmp" "${cmake_dependency_file}"
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Delete the temporary file and the nvcc-generated dependency file
+cuda_execute_process(
+  "Removing ${cmake_dependency_file}.tmp and ${NVCC_generated_dependency_file}"
+  COMMAND "${CMAKE_COMMAND}" -E remove "${cmake_dependency_file}.tmp" "${NVCC_generated_dependency_file}"
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Generate the code: compile source_file into generated_file with nvcc
+cuda_execute_process(
+  "Generating ${generated_file}"
+  COMMAND "${CUDA_NVCC_EXECUTABLE}"
+  "${source_file}"
+  ${cuda_language_flag}
+  ${format_flag} -o "${generated_file}"
+  ${CCBIN}
+  ${nvcc_flags}
+  ${nvcc_host_compiler_flags}
+  ${CUDA_NVCC_FLAGS}
+  -DNVCC
+  ${CUDA_NVCC_INCLUDE_ARGS}
+  )
+
+if(CUDA_result)
+  # Since nvcc can sometimes leave half done files make sure that we delete the output file.
+  cuda_execute_process(
+    "Removing ${generated_file}"
+    COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}"
+    )
+  message(FATAL_ERROR "Error generating file ${generated_file}")
+else()
+  if(verbose)
+    message("Generated ${generated_file} successfully.")
+  endif()
+endif()
+
+# Cubin resource report commands.
+if( build_cubin )
+  # Run with -cubin to produce resource usage report.
+  cuda_execute_process(
+    "Generating ${generated_cubin_file}"
+    COMMAND "${CUDA_NVCC_EXECUTABLE}"
+    "${source_file}"
+    ${CUDA_NVCC_FLAGS}
+    ${nvcc_flags}
+    ${CCBIN}
+    ${nvcc_host_compiler_flags}
+    -DNVCC
+    -cubin
+    -o "${generated_cubin_file}"
+    ${CUDA_NVCC_INCLUDE_ARGS}
+    )
+  # NOTE: CUDA_result of the cubin step above is not checked before parsing.
+  # Execute the parser script.
+  cuda_execute_process(
+    "Executing the parser script"
+    COMMAND  "${CMAKE_COMMAND}"
+    -D "input_file:STRING=${generated_cubin_file}"
+    -P "${CUDA_parse_cubin}"
+    )
+
+endif()
+
+cmake_policy(POP)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bf7edd69ccd13990b24350fdf217b156343724f4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
@@ -0,0 +1,300 @@
+# Synopsis:
+#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
+#   -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
+#      target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
+#       - "Auto" detects local machine GPU compute arch at runtime.
+#       - "Common" and "All" cover common and entire subsets of architectures
+#      ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
+#      NAME: Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere
+#      NUM: Any number. Only those pairs are currently accepted by NVCC though:
+#            3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5 8.0
+#      Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
+#      Additionally, sets ${out_variable}_readable to the resulting numeric list
+#      Example:
+#       CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
+#        LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
+#
+#      More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
+#
+
+# When CUDA is enabled as a first-class language, take CUDA_VERSION
+# (major.minor only) from the NVIDIA compiler version.
+if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+  if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA"
+      AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)")
+    set(CUDA_VERSION "${CMAKE_MATCH_1}")
+  endif()
+endif()
+
+# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
+
+# Three lists are built up below, extended per CUDA_VERSION:
+#   CUDA_KNOWN_GPU_ARCHITECTURES  - architecture names, used for "All"
+#   CUDA_COMMON_GPU_ARCHITECTURES - numeric archs (optionally +PTX), used for
+#                                   "Common" and as the autodetect fallback
+#   CUDA_ALL_GPU_ARCHITECTURES    - numeric archs
+# CUDA_LIMIT_GPU_ARCHITECTURE, when set, is the first compute capability the
+# toolkit is treated as NOT supporting; CUDA_DETECT_INSTALLED_GPUS replaces
+# detected archs at or above it.
+
+# This list will be used for CUDA_ARCH_NAME = All option
+set(CUDA_KNOWN_GPU_ARCHITECTURES  "Kepler" "Maxwell")
+
+# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
+set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0")
+
+# This list is used to filter CUDA archs when autodetecting
+set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0")
+
+# CUDA 11.0 and later: Ampere (8.0).
+if(CUDA_VERSION VERSION_GREATER "10.5")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")
+
+  if(CUDA_VERSION VERSION_LESS "11.1")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX")
+  endif()
+endif()
+
+# CUDA 11.1 and later: 8.6.
+if(NOT CUDA_VERSION VERSION_LESS "11.1")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
+  # NOTE(review): the variable name below is misspelled ("ARCHITECUTRE"), so
+  # this assignment never reaches CUDA_LIMIT_GPU_ARCHITECTURE, which is the
+  # name read during autodetection.  Do not simply correct the spelling: that
+  # would impose an 8.6 limit on every CUDA >= 11.8 toolkit and change which
+  # architectures autodetection returns.
+  set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6")
+
+  if(CUDA_VERSION VERSION_LESS "11.8")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX")
+  endif()
+endif()
+
+# CUDA 11.8 and later: Ada (8.9) and Hopper (9.0).
+if(NOT CUDA_VERSION VERSION_LESS "11.8")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0")
+
+  if(CUDA_VERSION VERSION_LESS "12.0")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX")
+  endif()
+endif()
+
+# CUDA 12.0 and later: add 9.0a and drop 3.5 from the numeric lists.
+if(NOT CUDA_VERSION VERSION_LESS "12.0")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a")
+  list(REMOVE_ITEM CUDA_COMMON_GPU_ARCHITECTURES "3.5")
+  list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5")
+endif()
+
+# CUDA newer than 12.6: Blackwell (10.x / 12.x).
+if(CUDA_VERSION VERSION_GREATER "12.6")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0a")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.1a")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.1a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0a")
+endif()
+
+
+################################################################################################
+# A function for automatic detection of GPUs installed  (if autodetection is enabled)
+# Usage:
+#   CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
+#
+function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
+  # Detects the compute capabilities of the GPUs in this machine by compiling
+  # and running a small probe program.
+  #
+  #   OUT_VARIABLE - name of the variable (in the caller's scope) receiving the
+  #                  detected architectures, or CUDA_COMMON_GPU_ARCHITECTURES
+  #                  when detection fails.
+  #
+  # A successful probe result is cached in CUDA_GPU_DETECT_OUTPUT, so the probe
+  # is only re-run while that cache entry is empty.
+  if(NOT CUDA_GPU_DETECT_OUTPUT)
+    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+      set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu")
+    else()
+      set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp")
+    endif()
+
+    # The probe prints "major.minor " for every device it can query and
+    # returns non-zero when no device is usable.  It needs the CUDA runtime
+    # API (cudaGetDeviceCount / cudaGetDeviceProperties) and std::printf.
+    file(WRITE ${file} ""
+      "#include <cuda_runtime.h>\n"
+      "#include <cstdio>\n"
+      "int main()\n"
+      "{\n"
+      "  int count = 0;\n"
+      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
+      "  if (count == 0) return -1;\n"
+      "  for (int device = 0; device < count; ++device)\n"
+      "  {\n"
+      "    cudaDeviceProp prop;\n"
+      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
+      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
+      "  }\n"
+      "  return 0;\n"
+      "}\n")
+
+    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+      try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
+              RUN_OUTPUT_VARIABLE compute_capabilities)
+    else()
+      # Plain C++ build of the probe: the CUDA headers and libraries have to
+      # be passed explicitly.
+      try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
+              CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
+              LINK_LIBRARIES ${CUDA_LIBRARIES}
+              RUN_OUTPUT_VARIABLE compute_capabilities)
+    endif()
+
+    # Filter unrelated content out of the output.
+    string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}")
+
+    if(run_result EQUAL 0)
+      # 2.1 devices are built as sm_21 code against the compute_20 virtual arch.
+      string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}")
+      set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities}
+        CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
+    endif()
+  endif()
+
+  if(NOT CUDA_GPU_DETECT_OUTPUT)
+    message(STATUS "Automatic GPU detection failed. Building for common architectures.")
+    set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
+  else()
+    # Filter based on CUDA version supported archs
+    set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
+    separate_arguments(CUDA_GPU_DETECT_OUTPUT)
+    foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
+      # Anything at or above the toolkit's limit is replaced by the last entry
+      # of the common list.
+      if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR
+                                          ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE))
+        list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
+        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
+      else()
+        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}")
+      endif()
+    endforeach()
+
+    set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE)
+  endif()
+endfunction()
+
+
+################################################################################################
+# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
+# Usage:
+#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
+function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
+  # Translates a list of architecture names / numbers into nvcc -gencode flags.
+  #
+  #   out_variable - name of the caller-scope variable that receives the flag
+  #                  list; ${out_variable}_readable receives a space-separated
+  #                  summary (sm_XX / compute_XX).
+  #   ARGN         - architectures: "Auto", "Common", "All", an architecture
+  #                  name, or NUM.NUM with optional "(NUM.NUM)" and "+PTX".
+  #                  Empty means "Auto".
+  set(CUDA_ARCH_LIST "${ARGN}")
+
+  if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
+    set(CUDA_ARCH_LIST "Auto")
+  endif()
+
+  # Accumulators: binary (SASS) targets and PTX targets across all entries.
+  set(cuda_arch_bin)
+  set(cuda_arch_ptx)
+
+  # Expand the three meta-values into concrete lists.
+  if("${CUDA_ARCH_LIST}" STREQUAL "All")
+    set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
+  elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
+    set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
+  elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
+    CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
+    message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
+  endif()
+
+  # Now process the list and look for names
+  # Whitespace-separated input is turned into a proper CMake list first.
+  string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
+  list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
+  foreach(arch_name ${CUDA_ARCH_LIST})
+    set(arch_bin)
+    set(arch_ptx)
+    set(add_ptx FALSE)
+    # Check to see if we are compiling PTX
+    if(arch_name MATCHES "(.*)\\+PTX$")
+      set(add_ptx TRUE)
+      set(arch_name ${CMAKE_MATCH_1})
+    endif()
+    # Numeric form: NUM.NUM, optional a/f suffix, optional "(NUM.NUM)".
+    if(arch_name MATCHES "^([0-9]+\\.[0-9][af]?(\\([0-9]+\\.[0-9]\\))?)$")
+      set(arch_bin ${CMAKE_MATCH_1})
+      set(arch_ptx ${arch_bin})
+    else()
+      # Look for it in our list of known architectures
+      # Entries that set only arch_bin get their PTX target from arch_bin
+      # further down, and only if +PTX was requested.
+      if(${arch_name} STREQUAL "Kepler+Tesla")
+        set(arch_bin 3.7)
+      elseif(${arch_name} STREQUAL "Kepler")
+        set(arch_bin 3.5)
+        set(arch_ptx 3.5)
+      elseif(${arch_name} STREQUAL "Maxwell+Tegra")
+        set(arch_bin 5.3)
+      elseif(${arch_name} STREQUAL "Maxwell")
+        set(arch_bin 5.0 5.2)
+        set(arch_ptx 5.2)
+      elseif(${arch_name} STREQUAL "Pascal")
+        set(arch_bin 6.0 6.1)
+        set(arch_ptx 6.1)
+     elseif(${arch_name} STREQUAL "Volta+Tegra")
+        set(arch_bin 7.2)
+      elseif(${arch_name} STREQUAL "Volta")
+        # The repeated 7.0 is harmless: duplicates are removed after the loop.
+        set(arch_bin 7.0 7.0)
+        set(arch_ptx 7.0)
+      elseif(${arch_name} STREQUAL "Turing")
+        set(arch_bin 7.5)
+        set(arch_ptx 7.5)
+      elseif(${arch_name} STREQUAL "Ampere+Tegra")
+        set(arch_bin 8.7)
+      elseif(${arch_name} STREQUAL "Ampere")
+        set(arch_bin 8.0 8.6)
+        set(arch_ptx 8.0 8.6)
+      elseif(${arch_name} STREQUAL "Ada")
+        set(arch_bin 8.9)
+        set(arch_ptx 8.9)
+      elseif(${arch_name} STREQUAL "Hopper")
+        set(arch_bin 9.0)
+        set(arch_ptx 9.0)
+      elseif(${arch_name} STREQUAL "Blackwell+Tegra")
+        set(arch_bin 10.1)
+      elseif(${arch_name} STREQUAL "Blackwell")
+        set(arch_bin 10.0 12.0)
+        set(arch_ptx 10.0 12.0)
+      else()
+        message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ")
+      endif()
+    endif()
+    if(NOT arch_bin)
+      message(SEND_ERROR "arch_bin wasn't set for some reason")
+    endif()
+    list(APPEND cuda_arch_bin ${arch_bin})
+    # PTX targets are only collected for entries that carried "+PTX".
+    if(add_ptx)
+      if (NOT arch_ptx)
+        set(arch_ptx ${arch_bin})
+      endif()
+      list(APPEND cuda_arch_ptx ${arch_ptx})
+    endif()
+  endforeach()
+
+  # remove dots and convert to lists
+  # e.g. "8.6" -> "86", "5.2(5.0)" -> "52(50)".  The PTX pattern does not keep
+  # parentheses, so a "X(Y)" entry yields two separate PTX items there.
+  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
+  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
+  string(REGEX MATCHALL "[0-9()]+[af]?" cuda_arch_bin "${cuda_arch_bin}")
+  string(REGEX MATCHALL "[0-9]+[af]?"   cuda_arch_ptx "${cuda_arch_ptx}")
+
+  if(cuda_arch_bin)
+    list(REMOVE_DUPLICATES cuda_arch_bin)
+  endif()
+  if(cuda_arch_ptx)
+    list(REMOVE_DUPLICATES cuda_arch_ptx)
+  endif()
+
+  set(nvcc_flags "")
+  set(nvcc_archs_readable "")
+
+  # Tell NVCC to add binaries for the specified GPUs
+  foreach(arch ${cuda_arch_bin})
+    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
+      # User explicitly specified ARCH for the concrete CODE
+      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
+      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
+    else()
+      # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE
+      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
+      list(APPEND nvcc_archs_readable sm_${arch})
+    endif()
+  endforeach()
+
+  # Tell NVCC to add PTX intermediate code for the specified architectures
+  foreach(arch ${cuda_arch_ptx})
+    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
+    list(APPEND nvcc_archs_readable compute_${arch})
+  endforeach()
+
+  # Results are exported to the caller; the readable form is space-separated.
+  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
+  set(${out_variable}          ${nvcc_flags}          PARENT_SCOPE)
+  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
+endfunction()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..6821cee4f77a9d84c74f2c140870a2163ae5a5f0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake
@@ -0,0 +1,47 @@
+# Distributed under the OSI-approved BSD 3-Clause License.  See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+#.rst:
+# FindPackageMessage
+# ------------------
+#
+#
+#
+# FIND_PACKAGE_MESSAGE(<name> "message for user" "find result details")
+#
+# This macro is intended to be used in FindXXX.cmake modules files.  It
+# will print a message once for each unique find result.  This is useful
+# for telling the user where a package was found.  The first argument
+# specifies the name (XXX) of the package.  The second argument
+# specifies the message to display.  The third argument lists details
+# about the find result so that if they change the message will be
+# displayed again.  The macro also obeys the QUIET argument to the
+# find_package command.
+#
+# Example:
+#
+# ::
+#
+#   if(X11_FOUND)
+#     FIND_PACKAGE_MESSAGE(X11 "Found X11: ${X11_X11_LIB}"
+#       "[${X11_X11_LIB}][${X11_INCLUDE_DIR}]")
+#   else()
+#    ...
+#   endif()
+
+function(FIND_PACKAGE_MESSAGE pkg msg details)
+  # Prints `msg` once per distinct `details` string for package `pkg`.
+  #   pkg     - package name; selects the ${pkg}_FIND_QUIETLY switch and the
+  #             cache entry that remembers the last reported details
+  #   msg     - text shown to the user (as a STATUS message)
+  #   details - fingerprint of the find result; a change triggers a new print
+
+  # Respect find_package(... QUIET): nothing is printed and nothing recorded.
+  if(${pkg}_FIND_QUIETLY)
+    return()
+  endif()
+
+  # Newlines are stripped before the details are compared and stored.
+  string(REPLACE "\n" "" _fpm_details "${details}")
+  set(_fpm_cache_key FIND_PACKAGE_MESSAGE_DETAILS_${pkg})
+
+  # Same details as last time: this result has already been reported.
+  if("${_fpm_details}" STREQUAL "${${_fpm_cache_key}}")
+    return()
+  endif()
+
+  message(STATUS "${msg}")
+
+  # Remember what was reported so later configure runs stay silent.
+  set("${_fpm_cache_key}" "${_fpm_details}"
+    CACHE INTERNAL "Details about finding ${pkg}")
+endfunction()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..7ecaff5109f42efb336b30a6ef0ad429a30051d3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake
@@ -0,0 +1,257 @@
+set(PYTORCH_FOUND_HIP FALSE)
+
+# Resolve ROCM_PATH.
+#  * ROCM_PATH not set in the environment: fall back to the conventional
+#    install prefix.  If that prefix is missing too, report it as a status
+#    message and stop processing this file (build without ROCm).
+#  * ROCM_PATH set in the environment: the user wants a ROCm build, so a
+#    non-existent path is a hard error.
+if(NOT DEFINED ENV{ROCM_PATH})
+  if(UNIX)
+    set(ROCM_PATH /opt/rocm)
+  else() # Win32
+    set(ROCM_PATH C:/opt/rocm)
+  endif()
+  if(NOT EXISTS ${ROCM_PATH})
+    message(STATUS
+        "ROCM_PATH environment variable is not set and ${ROCM_PATH} does not exist.\n"
+        "Building without ROCm support.")
+    return()
+  endif()
+else()
+  file(TO_CMAKE_PATH "$ENV{ROCM_PATH}" ROCM_PATH)
+  if(NOT EXISTS ${ROCM_PATH})
+    message(FATAL_ERROR
+      "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n"
+      "Set a valid ROCM_PATH or unset ROCM_PATH environment variable to fix.")
+  endif()
+endif()
+
+# MAGMA_HOME: an environment value wins; otherwise default to the magma
+# directory under ROCM_PATH and export that default to the environment too.
+if(DEFINED ENV{MAGMA_HOME})
+  file(TO_CMAKE_PATH "$ENV{MAGMA_HOME}" MAGMA_HOME)
+else()
+  set(MAGMA_HOME ${ROCM_PATH}/magma)
+  set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma)
+endif()
+
+# On Windows MIOpen is not shipped with the HIP SDK, so its CMake package may
+# live under a separate prefix (MIOPEN_PATH, if set in the environment).
+if(WIN32)
+  if(DEFINED ENV{MIOPEN_PATH})
+    set(miopen_DIR $ENV{MIOPEN_PATH}/lib/cmake/miopen)
+  else()
+    set(miopen_DIR C:/opt/miopen/lib/cmake/miopen)
+  endif()
+endif()
+
+# The list of GPU architectures to build for must not be empty.
+torch_hip_get_arch_list(PYTORCH_ROCM_ARCH)
+if(PYTORCH_ROCM_ARCH STREQUAL "")
+  message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.")
+endif()
+message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")
+
+# Put the directory holding FindHIP.cmake at the front of CMAKE_MODULE_PATH;
+# the HIP package is located with a Module mode search, see
+# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes
+if(NOT UNIX) # Win32
+  set(CMAKE_MODULE_PATH ${ROCM_PATH}/cmake/ ${CMAKE_MODULE_PATH})
+else()
+  set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})
+endif()
+
+# The individual ROCm components are located with Config mode searches, which
+# look under CMAKE_PREFIX_PATH.
+list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+
+# find_package_and_print_version(<package> [find_package args...])
+# Forwards to find_package(), then either prints the package version and
+# records its include directory in ROCM_INCLUDE_DIRS, or reports the package
+# as missing.
+macro(find_package_and_print_version PACKAGE_NAME)
+  find_package("${PACKAGE_NAME}" ${ARGN})
+  if(${PACKAGE_NAME}_FOUND)
+    message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
+    # Not every package exports <package>_INCLUDE_DIR.
+    if(${PACKAGE_NAME}_INCLUDE_DIR)
+      list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
+    endif()
+  else()
+    message("Optional package ${PACKAGE_NAME} not found")
+  endif()
+endmacro()
+
+# Find the HIP Package
+# MODULE argument is added for clarity that CMake is searching
+# for FindHIP.cmake in Module mode
+find_package_and_print_version(HIP 1.0 MODULE)
+
+# Everything below only runs when the HIP module was found: it locates the
+# ROCm component packages, derives ROCM_VERSION_DEV* / TORCH_HIP_VERSION and
+# probes hipblaslt for optional scaling features.
+if(HIP_FOUND)
+  set(PYTORCH_FOUND_HIP TRUE)
+  find_package_and_print_version(hip REQUIRED CONFIG)
+  if(HIP_VERSION)
+    # Check if HIP_VERSION contains a dash (e.g., "7.1.25421-32f9fa6ca5")
+    # and strip everything after it to get clean numeric version
+    string(FIND "${HIP_VERSION}" "-" DASH_POS)
+    if(NOT DASH_POS EQUAL -1)
+      string(SUBSTRING "${HIP_VERSION}" 0 ${DASH_POS} HIP_VERSION_CLEAN)
+      set(HIP_VERSION "${HIP_VERSION_CLEAN}")
+    endif()
+    message("HIP version: ${HIP_VERSION}")
+  endif()
+
+  # The rocm-core package was only introduced in ROCm 6.4, so we make it optional.
+  find_package(rocm-core CONFIG)
+
+  # Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow
+  # falling back to the hip version, which everyone should have.
+  # rocm_version.h lives in the rocm-core package and hip_version.h lives in the
+  # hip (lower-case) package. Both are probed above and will be in
+  # ROCM_INCLUDE_DIRS if available.
+  find_file(ROCM_VERSION_HEADER_PATH
+    NAMES rocm-core/rocm_version.h hip/hip_version.h
+    NO_DEFAULT_PATH
+    PATHS ${ROCM_INCLUDE_DIRS}
+  )
+  # The macro prefix inside the header depends on which header was found.
+  if(ROCM_VERSION_HEADER_PATH MATCHES "rocm-core/rocm_version.h$")
+    set(ROCM_LIB_NAME "ROCM")
+  else()
+    set(ROCM_LIB_NAME "HIP")
+  endif()
+
+  if(NOT ROCM_VERSION_HEADER_PATH)
+    message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}")
+  endif()
+  get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)
+
+  if(EXISTS ${ROCM_VERSION_HEADER_PATH})
+    set(ROCM_HEADER_FILE ${ROCM_VERSION_HEADER_PATH})
+  else()
+    message(FATAL_ERROR "********************* ${ROCM_HEADER_NAME} could not be found ******************\n")
+  endif()
+
+  # Read the ROCM headerfile into a variable
+  # (the content is logged after the read; logging it before would always
+  # print an empty string)
+  message(STATUS "Reading ROCM version from: ${ROCM_HEADER_FILE}")
+  file(READ "${ROCM_HEADER_FILE}" ROCM_HEADER_CONTENT)
+  message(STATUS "Content: ${ROCM_HEADER_CONTENT}")
+
+  # Below we use a RegEx to find ROCM version numbers.
+  # Note that CMake does not support \s for blank space. That is
+  # why in the regular expressions below we have a blank space in
+  # the square brackets.
+  # There are three steps:
+  # 1. Match regular expression
+  # 2. Strip the non-numerical part of the string
+  # 3. Strip leading and trailing spaces
+
+  string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
+  string(REPLACE "${ROCM_LIB_NAME}_VERSION_MAJOR" "" TEMP2 ${TEMP1})
+  string(STRIP ${TEMP2} ROCM_VERSION_DEV_MAJOR)
+  string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
+  string(REPLACE "${ROCM_LIB_NAME}_VERSION_MINOR" "" TEMP2 ${TEMP1})
+  string(STRIP ${TEMP2} ROCM_VERSION_DEV_MINOR)
+  string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
+  string(REPLACE "${ROCM_LIB_NAME}_VERSION_PATCH" "" TEMP2 ${TEMP1})
+  string(STRIP ${TEMP2} ROCM_VERSION_DEV_PATCH)
+
+  # Create ROCM_VERSION_DEV_INT which is later used as a preprocessor macros
+  set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
+  math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
+
+  message("\n***** ROCm version from ${ROCM_HEADER_NAME} ****\n")
+  message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
+  message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
+  message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
+  message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
+  message("ROCM_VERSION_DEV_INT:   ${ROCM_VERSION_DEV_INT}")
+
+  math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
+  message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}")
+  message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}")
+  message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}")
+
+  # Find ROCM components using Config mode
+  # These components will be searched for recursively in ${ROCM_PATH}
+  message("\n***** Library versions from cmake find_package *****\n")
+  find_package_and_print_version(amd_comgr REQUIRED)
+  find_package_and_print_version(rocrand REQUIRED)
+  find_package_and_print_version(hiprand REQUIRED)
+  find_package_and_print_version(rocblas REQUIRED)
+  find_package_and_print_version(hipblas REQUIRED)
+  find_package_and_print_version(miopen REQUIRED)
+  find_package_and_print_version(hipfft REQUIRED)
+  find_package_and_print_version(hipsparse REQUIRED)
+  find_package_and_print_version(rocprim REQUIRED)
+  find_package_and_print_version(hipcub REQUIRED)
+  find_package_and_print_version(rocthrust REQUIRED)
+  find_package_and_print_version(hipsolver REQUIRED)
+  find_package_and_print_version(rocsolver REQUIRED)
+  # workaround cmake 4 build issue
+  if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0")
+    message(WARNING "Work around hiprtc cmake failure for cmake >= 4")
+    set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
+    find_package_and_print_version(hiprtc REQUIRED)
+    unset(CMAKE_POLICY_VERSION_MINIMUM)
+  else()
+    find_package_and_print_version(hiprtc REQUIRED)
+  endif()
+  find_package_and_print_version(hipblaslt REQUIRED)
+
+  if(UNIX)
+    find_package_and_print_version(rccl)
+    find_package_and_print_version(hsa-runtime64 REQUIRED)
+  endif()
+
+  # Optional components.
+  find_package_and_print_version(hipsparselt)  # Will be required when ready.
+
+  list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)
+
+  if(UNIX)
+    # roctx is part of roctracer
+    find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
+
+    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
+
+    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+      # Both probes below compile a one-line program against the hipblaslt
+      # header; a successful compile means the enumerator exists.
+
+      # check whether hipblaslt provides HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
+      set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_outer_vec.cc")
+      file(WRITE ${file} ""
+        "#define LEGACY_HIPBLAS_DIRECT\n"
+        "#include <hipblaslt/hipblaslt.h>\n"
+        "int main() {\n"
+        "    hipblasLtMatmulMatrixScale_t attr = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;\n"
+        "    return 0;\n"
+        "}\n"
+        )
+      try_compile(hipblaslt_compile_result_outer_vec ${PROJECT_RANDOM_BINARY_DIR} ${file}
+        CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+        COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+        OUTPUT_VARIABLE hipblaslt_compile_output_outer_vec)
+
+      # check whether hipblaslt provides HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT
+      set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_vec_ext.cc")
+      file(WRITE ${file} ""
+        "#define LEGACY_HIPBLAS_DIRECT\n"
+        "#include <hipblaslt/hipblaslt.h>\n"
+        "int main() {\n"
+        "    hipblasLtMatmulDescAttributes_t attr = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;\n"
+        "    return 0;\n"
+        "}\n"
+        )
+      try_compile(hipblaslt_compile_result_vec_ext ${PROJECT_RANDOM_BINARY_DIR} ${file}
+        CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+        COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+        OUTPUT_VARIABLE hipblaslt_compile_output_vec_ext)
+
+      # The outer-vec API takes precedence when both probes succeed.
+      if(hipblaslt_compile_result_outer_vec)
+        set(HIPBLASLT_OUTER_VEC ON)
+        set(HIPBLASLT_VEC_EXT OFF)
+        message("hipblaslt is using scale pointer outer vec")
+      elseif(hipblaslt_compile_result_vec_ext)
+        set(HIPBLASLT_OUTER_VEC OFF)
+        set(HIPBLASLT_VEC_EXT ON)
+        message("hipblaslt is using scale pointer vec ext")
+      else()
+        set(HIPBLASLT_OUTER_VEC OFF)
+        set(HIPBLASLT_VEC_EXT OFF)
+        message("hipblaslt is NOT using scale pointer outer vec: ${hipblaslt_compile_output_outer_vec}")
+        message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output_vec_ext}")
+      endif()
+    endif()
+  endif()
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bc8855d23e61fbbe5979beae22ab6086a388ba1f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake
@@ -0,0 +1,391 @@
+# ---[ cuda
+
+# Poor man's include guard
+# (torch::cudart is created further down in this file, so its existence means
+# the file has already been processed)
+if(TARGET torch::cudart)
+  return()
+endif()
+
+# sccache is only supported in CMake master and not in the newest official
+# release (3.11.3) yet. Hence we need our own Modules_CUDA_fix to enable sccache.
+list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/../Modules_CUDA_fix)
+
+# We don't want to statically link cudart, because we rely on its dynamic linkage in
+# python (follow along torch/cuda/__init__.py and usage of cudaGetErrorName).
+# Technically, we can link cudart here statically, and link libtorch_python.so
+# to a dynamic libcudart.so, but that's just wasteful.
+# However, on Windows, if this one gets switched off, the error "cuda: unknown error"
+# will be raised when running the following code:
+# >>> import torch
+# >>> torch.cuda.is_available()
+# >>> torch.cuda.current_device()
+# More details can be found in the following links.
+# https://github.com/pytorch/pytorch/issues/20635
+# https://github.com/pytorch/pytorch/issues/17108
+if(NOT MSVC)
+  set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "")
+endif()
+
+# Find CUDA.
+# Not finding CUDA is fatal only when the user explicitly asked for it;
+# otherwise CAFFE2_USE_CUDA is switched off and the rest of this file is
+# skipped.
+find_package(CUDA)
+if(NOT CUDA_FOUND)
+  # If user explicitly set USE_CUDA=1, error out instead of falling back
+  if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA)
+    message(FATAL_ERROR
+      "PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. "
+      "Please check your CUDA installation, ensure CUDA toolkit is installed, "
+      "and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. "
+      "If you want to build without CUDA, please set USE_CUDA=0.")
+  endif()
+
+  message(WARNING
+    "PyTorch: CUDA cannot be found. Depending on whether you are building "
+    "PyTorch or a PyTorch dependent library, the next warning / error will "
+    "give you more info.")
+  set(CAFFE2_USE_CUDA OFF)
+  return()
+endif()
+
+# Enable CUDA language support
+# CUDAToolkit_ROOT points the later find_package(CUDAToolkit) at the toolkit
+# that find_package(CUDA) located.
+set(CUDAToolkit_ROOT "${CUDA_TOOLKIT_ROOT_DIR}")
+# When the C++ compiler is Clang, use it as the CUDA host compiler as well.
+# According to the docs this must be done before the CUDA language is enabled, see
+# https://cmake.org/cmake/help/v3.15/variable/CMAKE_CUDA_HOST_COMPILER.html
+if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
+  set(CMAKE_CUDA_HOST_COMPILER "${CMAKE_CXX_COMPILER}")
+endif()
+enable_language(CUDA)
+# Default the CUDA language standard to the C++ standard unless already set.
+if("X${CMAKE_CUDA_STANDARD}" STREQUAL "X" )
+  set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})
+endif()
+set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+
+# CMP0074 - find_package will respect <PackageName>_ROOT variables
+# The policy change is scoped to the find_package(CUDAToolkit) call by the
+# PUSH/POP pair.
+cmake_policy(PUSH)
+if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.12.0)
+  cmake_policy(SET CMP0074 NEW)
+endif()
+
+find_package(CUDAToolkit REQUIRED)
+
+cmake_policy(POP)
+
+# The CUDA compiler and the CUDAToolkit package must agree on the version.
+if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_EQUAL CUDAToolkit_VERSION)
+  message(FATAL_ERROR "Found two conflicting CUDA versions:\n"
+                      "V${CMAKE_CUDA_COMPILER_VERSION} in '${CUDA_INCLUDE_DIRS}' and\n"
+                      "V${CUDAToolkit_VERSION} in '${CUDAToolkit_INCLUDE_DIRS}'")
+endif()
+
+message(STATUS "PyTorch: CUDA detected: " ${CUDA_VERSION})
+message(STATUS "PyTorch: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE})
+message(STATUS "PyTorch: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR})
+if(CUDA_VERSION VERSION_LESS 12.0)
+  message(FATAL_ERROR "PyTorch requires CUDA 12.0 or above.")
+endif()
+
+# Cross-check the CUDA version reported by FindCUDA against the version the
+# CUDA headers on the include path declare (CUDA_VERSION in cuda.h).  Skipped
+# when cross-compiling because the check has to run a program.
+if(CUDA_FOUND)
+  # Sometimes, we may mismatch nvcc with the CUDA headers we are
+  # compiling with, e.g., if a ccache nvcc is fed to us by CUDA_NVCC_EXECUTABLE
+  # but the PATH is not consistent with CUDA_HOME.  It's better safe
+  # than sorry: make sure everything is consistent.
+  if(MSVC AND CMAKE_GENERATOR MATCHES "Visual Studio")
+    # When using Visual Studio, it attempts to lock the whole binary dir when
+    # `try_run` is called, which will cause the build to fail.
+    string(RANDOM BUILD_SUFFIX)
+    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}/${BUILD_SUFFIX}")
+  else()
+    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
+  endif()
+  # The probe prints "<major>.<minor>" decoded from the CUDA_VERSION macro; it
+  # needs cuda.h for that macro and cstdio for printf.
+  set(file "${PROJECT_BINARY_DIR}/detect_cuda_version.cc")
+  file(WRITE ${file} ""
+    "#include <cuda.h>\n"
+    "#include <cstdio>\n"
+    "int main() {\n"
+    "  printf(\"%d.%d\", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);\n"
+    "  return 0;\n"
+    "}\n"
+    )
+  if(NOT CMAKE_CROSSCOMPILING)
+    try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
+      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
+      LINK_LIBRARIES ${CUDA_LIBRARIES}
+      RUN_OUTPUT_VARIABLE cuda_version_from_header
+      COMPILE_OUTPUT_VARIABLE output_var
+      )
+    if(NOT compile_result)
+      message(FATAL_ERROR "PyTorch: Couldn't determine version from header: " ${output_var})
+    endif()
+    message(STATUS "PyTorch: Header version is: " ${cuda_version_from_header})
+    # Quoted so that an empty CUDA_VERSION_STRING is compared as an empty
+    # string instead of producing a malformed if() expression.
+    if(NOT cuda_version_from_header STREQUAL "${CUDA_VERSION_STRING}")
+      # Force CUDA to be processed for again next time
+      # TODO: I'm not sure if this counts as an implementation detail of
+      # FindCUDA
+      set(cuda_version_from_findcuda ${CUDA_VERSION_STRING})
+      unset(CUDA_TOOLKIT_ROOT_DIR_INTERNAL CACHE)
+      # Not strictly necessary, but for good luck.
+      unset(CUDA_VERSION CACHE)
+      # Error out
+      message(FATAL_ERROR "FindCUDA says CUDA version is ${cuda_version_from_findcuda} (usually determined by nvcc), "
+        "but the CUDA headers say the version is ${cuda_version_from_header}.  This often occurs "
+        "when you set both CUDA_HOME and CUDA_NVCC_EXECUTABLE to "
+        "non-standard locations, without also setting PATH to point to the correct nvcc.  "
+        "Perhaps, try re-running this command again with PATH=${CUDA_TOOLKIT_ROOT_DIR}/bin:$PATH.  "
+        "See above log messages for more diagnostics, and see https://github.com/pytorch/pytorch/issues/8092 for more details.")
+    endif()
+  endif()
+endif()
+
+# ---[ CUDA libraries wrapper
+
+# find libnvrtc.so
+# CUDA_NVRTC_SHORTHASH is the first 8 hex digits of the SHA-256 of the nvrtc
+# library file.  It is only computed when not already set; on failure the
+# placeholder "XXXXXXXX" is used.
+set(CUDA_NVRTC_LIB "${CUDA_nvrtc_LIBRARY}" CACHE FILEPATH "")
+if(CUDA_NVRTC_LIB AND NOT CUDA_NVRTC_SHORTHASH)
+  find_package(Python COMPONENTS Interpreter)
+  # execute_process() launches the command directly and does not resolve
+  # imported targets such as Python::Interpreter, so the interpreter has to be
+  # given by path.
+  execute_process(
+    COMMAND "${Python_EXECUTABLE}" -c
+    "import hashlib;hash=hashlib.sha256();hash.update(open('${CUDA_NVRTC_LIB}','rb').read());print(hash.hexdigest()[:8])"
+    RESULT_VARIABLE _retval
+    OUTPUT_VARIABLE CUDA_NVRTC_SHORTHASH)
+  if(NOT _retval EQUAL 0)
+    message(WARNING "Failed to compute shorthash for libnvrtc.so")
+    set(CUDA_NVRTC_SHORTHASH "XXXXXXXX")
+  else()
+    # Drop the trailing newline emitted by print().
+    string(STRIP "${CUDA_NVRTC_SHORTHASH}" CUDA_NVRTC_SHORTHASH)
+    message(STATUS "${CUDA_NVRTC_LIB} shorthash is ${CUDA_NVRTC_SHORTHASH}")
+  endif()
+endif()
+
+# Create new style imported libraries: INTERFACE IMPORTED wrappers that forward
+# to the CUDA::* targets. When CAFFE2_STATIC_LINK_CUDA is set, the wrappers
+# select static variants (CUDA::*_static) where this file provides a branch.
+# This flag should only be used for binary builds, so end-users should never
+# have this flag set.
+
+# cuda (driver API): always forwards to CUDA::cuda_driver
+add_library(caffe2::cuda INTERFACE IMPORTED)
+set_property(
+    TARGET caffe2::cuda PROPERTY INTERFACE_LINK_LIBRARIES
+    CUDA::cuda_driver)
+
+# cudart (runtime): static or shared depending on CAFFE2_STATIC_LINK_CUDA
+add_library(torch::cudart INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA)
+    set_property(
+        TARGET torch::cudart PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cudart_static)
+else()
+    set_property(
+        TARGET torch::cudart PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cudart)
+endif()
+
+
+# cublas: forwards to CUDA::cublas and CUDA::cublasLt in both branches
+add_library(caffe2::cublas INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+    set_property(
+        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
+        # NOTE: cublas is always linked dynamically
+        CUDA::cublas CUDA::cublasLt)
+    set_property(
+        TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cudart_static rt)  # static CUDA build: additionally link the static runtime and rt
+else()
+    set_property(
+        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cublas CUDA::cublasLt)
+endif()
+
+# cudnn interface: defines torch::cudnn when CAFFE2_USE_CUDNN is enabled
+# static linking is requested via USE_STATIC_CUDNN, which seeds the CUDNN_STATIC cache entry (set without FORCE)
+if(CAFFE2_USE_CUDNN)
+  if(USE_STATIC_CUDNN)
+    set(CUDNN_STATIC ON CACHE BOOL "")
+  else()
+    set(CUDNN_STATIC OFF CACHE BOOL "")
+  endif()
+
+  find_package(CUDNN)
+
+  if(NOT CUDNN_FOUND)
+    message(WARNING
+      "Cannot find cuDNN library. Turning the option off")
+    set(CAFFE2_USE_CUDNN OFF)
+  else()
+    if(CUDNN_VERSION VERSION_LESS "8.1.0")  # minimum supported cuDNN version
+      message(FATAL_ERROR "PyTorch requires cuDNN 8.1 and above.")
+    endif()
+  endif()
+
+  add_library(torch::cudnn INTERFACE IMPORTED)  # NOTE(review): also reached when cuDNN was not found above - confirm intended
+  target_include_directories(torch::cudnn INTERFACE ${CUDNN_INCLUDE_PATH})
+  if(CUDNN_STATIC AND NOT WIN32)
+    target_link_options(torch::cudnn INTERFACE
+        "-Wl,--exclude-libs,libcudnn_static.a")  # only a linker option is set in this branch; no library is added here
+  else()
+    target_link_libraries(torch::cudnn INTERFACE ${CUDNN_LIBRARY_PATH})
+  endif()
+else()
+  message(STATUS "USE_CUDNN is set to 0. Compiling without cuDNN support")
+endif()
+
+if(CAFFE2_USE_CUSPARSELT)  # optional cuSPARSELt support: defines torch::cusparselt when the package is found
+  find_package(CUSPARSELT)
+
+  if(NOT CUSPARSELT_FOUND)
+    message(WARNING
+      "Cannot find cuSPARSELt library. Turning the option off")
+    set(CAFFE2_USE_CUSPARSELT OFF)  # downgrade to a build without cuSPARSELt instead of failing
+  else()
+    add_library(torch::cusparselt INTERFACE IMPORTED)
+    target_include_directories(torch::cusparselt INTERFACE ${CUSPARSELT_INCLUDE_PATH})
+    target_link_libraries(torch::cusparselt INTERFACE ${CUSPARSELT_LIBRARY_PATH})
+  endif()
+else()
+  message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
+endif()
+
+if(USE_CUDSS)  # optional cuDSS support: defines torch::cudss when the package is found
+  find_package(CUDSS)
+
+  if(NOT CUDSS_FOUND)
+    message(WARNING
+      "Cannot find CUDSS library. Turning the option off")
+    set(USE_CUDSS OFF)  # downgrade to a build without cuDSS instead of failing
+  else()
+    add_library(torch::cudss INTERFACE IMPORTED)
+    target_include_directories(torch::cudss INTERFACE ${CUDSS_INCLUDE_PATH})
+    target_link_libraries(torch::cudss INTERFACE ${CUDSS_LIBRARY_PATH})
+  endif()
+else()
+  message(STATUS "USE_CUDSS is set to 0. Compiling without cuDSS support")
+endif()
+
+# cufile: torch::cufile is only defined when CAFFE2_USE_CUFILE is enabled
+if(CAFFE2_USE_CUFILE)
+  add_library(torch::cufile INTERFACE IMPORTED)
+  if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)  # the static variant is never selected on Windows
+      set_property(
+          TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cuFile_static)
+  else()
+      set_property(
+          TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cuFile)
+  endif()
+else()
+  message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
+endif()
+
+# curand: forwards to CUDA::curand_static for static non-Windows builds, CUDA::curand otherwise
+add_library(caffe2::curand INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+    set_property(
+        TARGET caffe2::curand PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::curand_static)
+else()
+    set_property(
+        TARGET caffe2::curand PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::curand)
+endif()
+
+# cufft: static builds pick a different static target depending on the CUDA version
+add_library(caffe2::cufft INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)  # up to and including 12.9: use the "nocallback" static target
+      set_property(
+          TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cufft_static_nocallback)
+    else()  # newer CUDA versions: use CUDA::cufft_static
+      set_property(
+          TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cufft_static)
+    endif()
+else()
+    set_property(
+        TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cufft)
+endif()
+
+# nvrtc: forwards to CUDA::nvrtc plus the caffe2::cuda driver wrapper
+add_library(caffe2::nvrtc INTERFACE IMPORTED)
+set_property(
+    TARGET caffe2::nvrtc PROPERTY INTERFACE_LINK_LIBRARIES
+    CUDA::nvrtc caffe2::cuda)
+
+# Add onnx namespace definition to nvcc (falls back to onnx_c2 when ONNX_NAMESPACE is unset)
+if(ONNX_NAMESPACE)
+  list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=${ONNX_NAMESPACE}")
+else()
+  list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=onnx_c2")
+endif()
+
+# Don't activate VC env again for Ninja generators with MSVC on Windows if CUDAHOSTCXX is not defined
+# by adding --use-local-env. (CUDAHOSTCXX is checked in the environment, not as a CMake variable.)
+if(MSVC AND CMAKE_GENERATOR STREQUAL "Ninja" AND NOT DEFINED ENV{CUDAHOSTCXX})
+  list(APPEND CUDA_NVCC_FLAGS "--use-local-env")
+endif()
+
+# setting nvcc arch flags: the helper fills NVCC_FLAGS_EXTRA, which is appended to CUDA_NVCC_FLAGS below
+torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA)
+# CMake 3.18 adds integrated support for architecture selection, but we can't rely on it
+if(DEFINED CMAKE_CUDA_ARCHITECTURES)
+  message(WARNING
+          "pytorch is not compatible with `CMAKE_CUDA_ARCHITECTURES` and will ignore its value. "
+          "Please configure `TORCH_CUDA_ARCH_LIST` instead.")
+  set(CMAKE_CUDA_ARCHITECTURES OFF)  # override any user-provided value
+endif()
+
+list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
+message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA}")
+
+# disable some nvcc diagnostics that appear in boost, glog, gflags, opencv, etc.
+foreach(diag cc_clobber_ignored
+             field_without_dll_interface
+             base_class_has_different_dll_interface
+             dll_interface_conflict_none_assumed
+             dll_interface_conflict_dllexport_assumed
+             bad_friend_decl)
+  list(APPEND SUPPRESS_WARNING_FLAGS --diag_suppress=${diag})
+endforeach()
+string(REPLACE ";" "," SUPPRESS_WARNING_FLAGS "${SUPPRESS_WARNING_FLAGS}")  # join the list into one comma-separated argument
+list(APPEND CUDA_NVCC_FLAGS -Xcudafe ${SUPPRESS_WARNING_FLAGS})  # passed as the single value following -Xcudafe
+
+set(CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Werror")  # presumably keeps host -Werror from being forwarded to nvcc - verify where this variable is consumed
+if(MSVC)
+  list(APPEND CUDA_NVCC_FLAGS "--Werror" "cross-execution-space-call")
+  list(APPEND CUDA_NVCC_FLAGS "--no-host-device-move-forward")
+endif()
+
+# Debug and Release symbol support: on MSVC pick the host runtime flag per build configuration
+if(MSVC)
+  if(${CAFFE2_USE_MSVC_STATIC_RUNTIME})  # static runtime: /MT (/MTd for Debug)
+    string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MTd")
+    string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MT")
+    string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MT")
+    string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MT")
+  else()  # otherwise: /MD (/MDd for Debug)
+    string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MDd")
+    string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MD")
+    string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MD")
+    string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MD")
+  endif()
+  if(CUDA_NVCC_FLAGS MATCHES "Zi")  # when a flag containing "Zi" is present, also pass -FS to the host compiler
+    list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "-FS")
+  endif()
+elseif(CUDA_DEVICE_DEBUG)
+  list(APPEND CUDA_NVCC_FLAGS "-g" "-G")  # -G enables device code debugging symbols
+endif()
+
+# Set expt-relaxed-constexpr to suppress Eigen warnings
+list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
+
+# Set expt-extended-lambda to support lambda on device
+list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
+
+foreach(FLAG ${CUDA_NVCC_FLAGS})  # fold the CUDA_NVCC_FLAGS list into the space-separated CMAKE_CUDA_FLAGS string
+  string(FIND "${FLAG}" " " flag_space_position)
+  if(NOT flag_space_position EQUAL -1)  # an entry containing a space could not be told apart after joining
+    message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'")
+  endif()
+  string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}")
+endforeach()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..186cda1a909ab79431114d1c61de895069255389
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake
@@ -0,0 +1,83 @@
+# ---[ gflags
+
+# Try config mode first, then module mode; the legacy manual find below is the last resort.
+find_package(gflags CONFIG QUIET)
+if(NOT TARGET gflags)
+  find_package(gflags MODULE QUIET)
+endif()
+
+if(TARGET gflags)
+  message(STATUS "Caffe2: Found gflags with new-style gflags target.")
+elseif(GFLAGS_FOUND)
+  message(STATUS "Caffe2: Found gflags with old-style gflag starget.")
+  add_library(gflags UNKNOWN IMPORTED)  # wrap the old-style result variables in an imported target
+  set_property(
+      TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARY})
+  set_property(
+      TARGET gflags PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+      ${GFLAGS_INCLUDE_DIR})
+else()
+  message(STATUS
+      "Caffe2: Cannot find gflags automatically. Using legacy find.")
+
+  # - Try to find GFLAGS in the legacy way.
+  #
+  # The following variables are optionally searched for defaults
+  #  GFLAGS_ROOT_DIR: Base directory where all GFLAGS components are found
+  #
+  # The following are set after configuration is done:
+  #  GFLAGS_FOUND
+  #  GFLAGS_INCLUDE_DIR
+  #  GFLAGS_LIBRARY
+  #  GFLAGS_LIBRARY_RELEASE / GFLAGS_LIBRARY_DEBUG (WIN32 only)
+  include(FindPackageHandleStandardArgs)
+  set(GFLAGS_ROOT_DIR "" CACHE PATH "Folder contains Gflags")
+
+  # We are testing only a couple of files in the include directories
+  if(WIN32)
+    find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h
+        PATHS ${GFLAGS_ROOT_DIR}/src/windows)
+  else()
+    find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h
+        PATHS ${GFLAGS_ROOT_DIR})
+  endif()
+
+  if(WIN32)
+    find_library(GFLAGS_LIBRARY_RELEASE
+        NAMES libgflags
+        PATHS ${GFLAGS_ROOT_DIR}
+        PATH_SUFFIXES Release)
+
+    find_library(GFLAGS_LIBRARY_DEBUG
+        NAMES libgflags-debug
+        PATHS ${GFLAGS_ROOT_DIR}
+        PATH_SUFFIXES Debug)
+    set(GFLAGS_LIBRARY optimized ${GFLAGS_LIBRARY_RELEASE} debug ${GFLAGS_LIBRARY_DEBUG})  # NOTE(review): this optimized/debug keyword list is later passed to IMPORTED_LOCATION - confirm intended
+  else()
+    find_library(GFLAGS_LIBRARY gflags)
+  endif()
+
+  find_package_handle_standard_args(
+      gflags DEFAULT_MSG GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY)
+
+  if(GFLAGS_FOUND)
+    message(
+        STATUS
+        "Caffe2: Found gflags  (include: ${GFLAGS_INCLUDE_DIR}, "
+        "library: ${GFLAGS_LIBRARY})")
+    add_library(gflags UNKNOWN IMPORTED)  # same imported-target wrapper as the old-style branch above
+    set_property(
+        TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARY})
+    set_property(
+        TARGET gflags PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+        ${GFLAGS_INCLUDE_DIR})
+  endif()
+endif()
+
+# After above, we should have the gflags target now; if not, only warn and let the includer decide.
+if(NOT TARGET gflags)
+  message(WARNING
+      "Caffe2: gflags cannot be found. Depending on whether you are building "
+      "Caffe2 or a Caffe2 dependent library, the next warning / error will "
+      "give you more info.")
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/glog.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/glog.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bb03e81f29e3afed43ba95260cc5c298be881f72
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/glog.cmake
@@ -0,0 +1,70 @@
+# ---[ glog
+
+# Try config mode first, then module mode; the legacy manual find below is the last resort.
+find_package(glog CONFIG QUIET)
+if(NOT TARGET glog::glog)
+  find_package(glog MODULE QUIET)
+endif()
+
+if(TARGET glog::glog)
+  message(STATUS "Caffe2: Found glog with new-style glog target.")
+elseif(GLOG_FOUND)
+  message(
+      STATUS
+      "Caffe2: Found glog with old-style glog starget. Glog never shipped "
+      "old style glog targets, so somewhere in your cmake path there might "
+      "be a custom Findglog.cmake file that got triggered. We will make a "
+      "best effort to create the new style glog target for you.")
+  add_library(glog::glog UNKNOWN IMPORTED)  # wrap the old-style result variables in an imported target
+  set_property(
+      TARGET glog::glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARY})
+  set_property(
+      TARGET glog::glog PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+      ${GLOG_INCLUDE_DIR})
+else()
+  message(STATUS "Caffe2: Cannot find glog automatically. Using legacy find.")
+
+  # - Try to find Glog
+  #
+  # The following variables are optionally searched for defaults
+  #  GLOG_ROOT_DIR: Base directory where all GLOG components are found
+  #
+  # The following are set after configuration is done:
+  #  GLOG_FOUND
+  #  GLOG_INCLUDE_DIR (only searched for when NOT WIN32)
+  #  GLOG_LIBRARY
+  #
+
+  include(FindPackageHandleStandardArgs)
+  set(GLOG_ROOT_DIR "" CACHE PATH "Folder contains Google glog")
+  if(NOT WIN32)  # NOTE(review): on WIN32 GLOG_INCLUDE_DIR is never searched here, yet it is required below - confirm
+      find_path(GLOG_INCLUDE_DIR glog/logging.h
+          PATHS ${GLOG_ROOT_DIR})
+  endif()
+
+  find_library(GLOG_LIBRARY glog
+      PATHS ${GLOG_ROOT_DIR}
+      PATH_SUFFIXES lib lib64)
+
+  find_package_handle_standard_args(glog DEFAULT_MSG GLOG_INCLUDE_DIR GLOG_LIBRARY)
+
+  if(GLOG_FOUND)
+    message(STATUS
+        "Caffe2: Found glog (include: ${GLOG_INCLUDE_DIR}, "
+        "library: ${GLOG_LIBRARY})")
+    add_library(glog::glog UNKNOWN IMPORTED)  # same imported-target wrapper as the old-style branch above
+    set_property(
+        TARGET glog::glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARY})
+    set_property(
+        TARGET glog::glog PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+        ${GLOG_INCLUDE_DIR})
+  endif()
+endif()
+
+# After above, we should have the glog::glog target now; if not, only warn and let the includer decide.
+if(NOT TARGET glog::glog)
+  message(WARNING
+      "Caffe2: glog cannot be found. Depending on whether you are building "
+      "Caffe2 or a Caffe2 dependent library, the next warning / error will "
+      "give you more info.")
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..2f6d1fd905aa303cc240b058318acdfb2483e9ad
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake
@@ -0,0 +1,40 @@
+find_package(MKL QUIET)  # presumably provides MKL_INCLUDE_DIR, MKL_LIBRARIES and MKL_ROOT used below - verify against FindMKL
+
+if(TARGET caffe2::mkl)
+  return()  # target already defined (e.g. this file was included before): nothing to do
+endif()
+
+add_library(caffe2::mkl INTERFACE IMPORTED)
+target_include_directories(caffe2::mkl INTERFACE ${MKL_INCLUDE_DIR})
+target_link_libraries(caffe2::mkl INTERFACE ${MKL_LIBRARIES})
+foreach(MKL_LIB IN LISTS MKL_LIBRARIES)  # add the directory of every MKL library that exists on disk as a link directory
+  if(EXISTS "${MKL_LIB}")
+    get_filename_component(MKL_LINK_DIR "${MKL_LIB}" DIRECTORY)
+    if(IS_DIRECTORY "${MKL_LINK_DIR}")
+      target_link_directories(caffe2::mkl INTERFACE "${MKL_LINK_DIR}")
+    endif()
+  endif()
+endforeach()
+
+# TODO: This is a hack, it will not pick up architecture dependent
+# MKL libraries correctly; see https://github.com/pytorch/pytorch/issues/73008
+set_property(
+  TARGET caffe2::mkl APPEND PROPERTY INTERFACE_LINK_DIRECTORIES
+  ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64 ${MKL_ROOT}/lib/intel64_win ${MKL_ROOT}/lib/win-x64)  # APPEND: keep the per-library directories added by the loop above instead of overwriting them
+
+if(UNIX)
+  if(USE_STATIC_MKL)
+    foreach(MKL_LIB_PATH IN LISTS MKL_LIBRARIES)
+      if(NOT EXISTS "${MKL_LIB_PATH}")
+        continue()  # skip entries that are not existing files (e.g. bare library names)
+      endif()
+
+      get_filename_component(MKL_LIB_NAME "${MKL_LIB_PATH}" NAME)
+
+      # Match archive libraries starting with "libmkl_" (the dot of ".a" is escaped so it is matched literally)
+      if(MKL_LIB_NAME MATCHES "^libmkl_" AND MKL_LIB_NAME MATCHES "\\.a$")
+        target_link_options(caffe2::mkl INTERFACE "-Wl,--exclude-libs,${MKL_LIB_NAME}")
+      endif()
+    endforeach()
+  endif()
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..87935625f9bfb543d1cdc7f2b59f11e8d4a709e7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake
@@ -0,0 +1,18 @@
+set(MKLDNN_USE_NATIVE_ARCH ${USE_NATIVE_ARCH})  # copy the USE_NATIVE_ARCH setting into the MKLDNN-specific variable
+
+if(CPU_AARCH64)
+  include(${CMAKE_CURRENT_LIST_DIR}/ComputeLibrary.cmake)  # only included on AArch64 CPUs
+endif()
+
+find_package(MKLDNN QUIET)
+
+if(NOT TARGET caffe2::mkldnn)  # create the wrapper target unless it already exists
+  add_library(caffe2::mkldnn INTERFACE IMPORTED)
+endif()
+
+set_property(
+  TARGET caffe2::mkldnn PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+  ${MKLDNN_INCLUDE_DIR})
+set_property(
+  TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES
+  ${MKLDNN_LIBRARIES})  # both properties are overwritten (no APPEND) whether or not the target pre-existed
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..77ec3622b132dc7a7817716dd24ef986e6ac030d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake
@@ -0,0 +1,92 @@
+# ---[ Protobuf
+
+# We will try to use the config mode first, and then fall back to module mode.
+find_package(Protobuf CONFIG QUIET)
+if(NOT Protobuf_FOUND)
+  find_package(Protobuf MODULE QUIET)
+endif()
+
+if((TARGET protobuf::libprotobuf OR TARGET protobuf::libprotobuf-lite) AND TARGET protobuf::protoc)
+  # Hooray. This is the most ideal situation, meaning that you either have a
+  # Protobuf config file installed (like on Windows), or you are using a
+  # modern CMake that ships with a FindProtobuf.cmake file that produces
+  # modern targets.
+  message(STATUS "Caffe2: Found protobuf with new-style protobuf targets.")
+elseif(Protobuf_FOUND OR PROTOBUF_FOUND)
+  # If the modern targets are not present, we will generate them for you for
+  # backward compatibility. This is backported from CMake's new FindProtobuf.cmake
+  # content.
+  if((NOT PROTOBUF_LIBRARY) AND (NOT PROTOBUF_LITE_LIBRARY))  # only the upper-case variables are checked; the message also prints the mixed-case ones
+    message(FATAL_ERROR
+        "Caffe2: Found protobuf with old style targets, but could not find targets."
+        " PROTOBUF_LIBRARY: " ${PROTOBUF_LIBRARY}
+        " PROTOBUF_LITE_LIBRARY: " ${PROTOBUF_LITE_LIBRARY}
+        " Protobuf_LIBRARY: " ${Protobuf_LIBRARY}
+        " Protobuf_LITE_LIBRARY: " ${Protobuf_LITE_LIBRARY})
+  endif()
+  message(STATUS "Caffe2: Found protobuf with old-style protobuf targets.")
+
+  if(PROTOBUF_LIBRARY)  # full protobuf runtime -> protobuf::libprotobuf
+    if(NOT TARGET protobuf::libprotobuf)
+      add_library(protobuf::libprotobuf UNKNOWN IMPORTED)
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          INTERFACE_INCLUDE_DIRECTORIES "${PROTOBUF_INCLUDE_DIRS}")
+    endif()
+    if(EXISTS "${PROTOBUF_LIBRARY}")  # locations are only set for files that exist on disk
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          IMPORTED_LOCATION "${PROTOBUF_LIBRARY}")
+    endif()
+    if(EXISTS "${PROTOBUF_LIBRARY_RELEASE}")
+      set_property(TARGET protobuf::libprotobuf APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS RELEASE)
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          IMPORTED_LOCATION_RELEASE "${PROTOBUF_LIBRARY_RELEASE}")
+    endif()
+    if(EXISTS "${PROTOBUF_LIBRARY_DEBUG}")
+      set_property(TARGET protobuf::libprotobuf APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS DEBUG)
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          IMPORTED_LOCATION_DEBUG "${PROTOBUF_LIBRARY_DEBUG}")
+    endif()
+  endif()
+
+  if(PROTOBUF_LITE_LIBRARY)  # lite runtime -> protobuf::libprotobuf-lite, same scheme as above
+    if(NOT TARGET protobuf::libprotobuf-lite)
+      add_library(protobuf::libprotobuf-lite UNKNOWN IMPORTED)
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          INTERFACE_INCLUDE_DIRECTORIES "${PROTOBUF_INCLUDE_DIRS}")
+    endif()
+    if(EXISTS "${PROTOBUF_LITE_LIBRARY}")
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          IMPORTED_LOCATION "${PROTOBUF_LITE_LIBRARY}")
+    endif()
+    if(EXISTS "${PROTOBUF_LITE_LIBRARY_RELEASE}")
+      set_property(TARGET protobuf::libprotobuf-lite APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS RELEASE)
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          IMPORTED_LOCATION_RELEASE "${PROTOBUF_LITE_LIBRARY_RELEASE}")
+    endif()
+    if(EXISTS "${PROTOBUF_LITE_LIBRARY_DEBUG}")
+      set_property(TARGET protobuf::libprotobuf-lite APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS DEBUG)
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          IMPORTED_LOCATION_DEBUG "${PROTOBUF_LITE_LIBRARY_DEBUG}")
+    endif()
+  endif()
+
+  if(PROTOBUF_PROTOC_EXECUTABLE)  # protoc compiler -> imported executable protobuf::protoc
+    if(NOT TARGET protobuf::protoc)
+      add_executable(protobuf::protoc IMPORTED)
+    endif()
+    set_property(TARGET protobuf::protoc PROPERTY
+        IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE})
+  endif()
+endif()
+
+# After above, we should have the protobuf related target now; if not, only warn and let the includer decide.
+if((NOT TARGET protobuf::libprotobuf) AND (NOT TARGET protobuf::libprotobuf-lite))
+  message(WARNING
+      "Protobuf cannot be found. Depending on whether you are building Caffe2 "
+      "or a Caffe2 dependent library, the next warning / error will give you "
+      "more info.")
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/utils.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/utils.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..3cdf5fb914b1ddaad115332079cb66a13ac2aea9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/utils.cmake
@@ -0,0 +1,552 @@
+################################################################################################
+# Exclude and prepend functionalities
+function(exclude OUTPUT INPUT)
+  # Store INPUT minus every extra argument (ARGN) into the caller's variable named by OUTPUT.
+  foreach(_unwanted ${ARGN})
+    list(REMOVE_ITEM INPUT "${_unwanted}")
+  endforeach()
+  set(${OUTPUT} ${INPUT} PARENT_SCOPE)
+endfunction()
+
+function(prepend OUTPUT PREPEND)
+  set(_prefixed "")  # result list: every extra argument (ARGN) with PREPEND put in front
+  foreach(_entry ${ARGN})
+    list(APPEND _prefixed "${PREPEND}${_entry}")
+  endforeach()
+  set(${OUTPUT} ${_prefixed} PARENT_SCOPE)
+endfunction()
+
+################################################################################################
+# Parses a version string that might have values beyond major, minor, and patch
+# and sets <LIBNAME>_VERSION_MAJOR/_MINOR/_PATCH and <LIBNAME>_VERSION in the caller's scope.
+# Usage:
+#   caffe2_parse_version_str(<library_name> <version_string>)
+function(caffe2_parse_version_str LIBNAME VERSIONSTR)
+  string(REGEX REPLACE "^([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${VERSIONSTR}")
+  string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR  "${VERSIONSTR}")
+  string(REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${VERSIONSTR}")  # NOTE(review): unlike the two patterns above this one is not anchored with ^; a string without a third component presumably stays unchanged - confirm
+  set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE)  # any extra arguments (ARGN) are appended to each component value
+  set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE)
+  set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE)
+  set(${LIBNAME}_VERSION "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE)  # built from the function-local components (without ARGN)
+endfunction()
+
+###
+# Removes common indentation from a block of text to produce code suitable for
+# passing to `python -c`, or using with pycmd. This allows multiline code to be
+# nested nicely in the surrounding code structure.
+#
+# This function respects Python_EXECUTABLE if it is defined, otherwise it uses
+# `python3` and hopes for the best. A fatal error is raised if the command fails.
+#
+# Args:
+#     outvar : variable that will hold the dedented text (stripped stdout of the python command)
+#     text   : text to remove indentation from
+#
+function(dedent outvar text)
+  # Use Python_EXECUTABLE if it is defined, otherwise default to python3
+  if("${Python_EXECUTABLE}" STREQUAL "")
+    set(_python_exe "python3")
+  else()
+    set(_python_exe "${Python_EXECUTABLE}")
+  endif()
+  set(_fixup_cmd "import sys; from textwrap import dedent; print(dedent(sys.stdin.read()))")
+  file(WRITE "${CMAKE_BINARY_DIR}/indented.txt" "${text}")  # the text reaches python's stdin through this scratch file
+  execute_process(
+    COMMAND "${_python_exe}" -c "${_fixup_cmd}"
+    INPUT_FILE "${CMAKE_BINARY_DIR}/indented.txt"
+    RESULT_VARIABLE _dedent_exitcode
+    OUTPUT_VARIABLE _dedent_text)
+  if(NOT _dedent_exitcode EQUAL 0)
+    message(ERROR " Failed to remove indentation from: \n\"\"\"\n${text}\n\"\"\"
+    Python dedent failed with error code: ${_dedent_exitcode}")
+    message(FATAL_ERROR " Python dedent failed with error code: ${_dedent_exitcode}")  # this call is the one that aborts configuration
+  endif()
+  # Remove superfluous newlines (artifacts of print)
+  string(STRIP "${_dedent_text}" _dedent_text)
+  set(${outvar} "${_dedent_text}" PARENT_SCOPE)
+endfunction()
+
+
+function(pycmd_no_exit outvar exitcode cmd)
+  # Runs `cmd` via python -c without aborting on failure; uses Python_EXECUTABLE if defined, otherwise "python"
+  if("${Python_EXECUTABLE}" STREQUAL "")
+    set(_python_exe "python")
+  else()
+    set(_python_exe "${Python_EXECUTABLE}")
+  endif()
+  # run the actual command; stdout goes to _output and the exit status to _exitcode
+  execute_process(
+    COMMAND "${_python_exe}" -c "${cmd}"
+    RESULT_VARIABLE _exitcode
+    OUTPUT_VARIABLE _output)
+  # Remove superfluous newlines (artifacts of print)
+  string(STRIP "${_output}" _output)
+  set(${outvar} "${_output}" PARENT_SCOPE)  # stripped stdout, returned to the caller
+  set(${exitcode} "${_exitcode}" PARENT_SCOPE)  # returned as-is so the caller decides how to react
+endfunction()
+
+
+###
+# Helper function to run `python -c "<cmd>"` and capture the results of stdout
+#
+# Runs a python command and populates an outvar with the result of stdout.
+# Common indentation in the text of `cmd` is removed before the command is
+# executed, so the caller does not need to worry about indentation issues.
+#
+# The interpreter is chosen by the dedent and pycmd_no_exit helpers this
+# function calls. A fatal error is raised if the command exits with non-zero status.
+#
+# Args:
+#     outvar : variable that will hold the stdout of the python command
+#     cmd    : text representing a (possibly multiline) block of python code
+#
+function(pycmd outvar cmd)
+  dedent(_dedent_cmd "${cmd}")
+  pycmd_no_exit(_output _exitcode "${_dedent_cmd}")
+
+  if(NOT _exitcode EQUAL 0)
+    message(ERROR " Failed when running python code: \"\"\"\n${_dedent_cmd}\n\"\"\"")
+    message(FATAL_ERROR " Python command failed with error code: ${_exitcode}")  # this call is the one that aborts configuration
+  endif()
+  # Remove superfluous newlines (artifacts of print)
+  string(STRIP "${_output}" _output)
+  set(${outvar} "${_output}" PARENT_SCOPE)
+endfunction()
+
+
+##############################################################################
+# Macro to update cached options: sets `variable` to `value`, forcing the cache entry in main-repo builds.
+macro(caffe2_update_option variable value)
+  if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
+    get_property(__help_string CACHE ${variable} PROPERTY HELPSTRING)  # keep the entry's existing help text
+    set(${variable} ${value} CACHE BOOL "${__help_string}" FORCE)  # quoted: an empty help string must still occupy the docstring argument
+  else()
+    set(${variable} ${value})  # outside the main repo only a normal variable is set
+  endif()
+endmacro()
+
+
+##############################################################################
+# Add an interface library definition that is dependent on the source.
+#
+# It's probably easiest to explain why this macro exists, by describing
+# what things would look like if we didn't have this macro.
+#
+# Let's suppose we want to statically link against torch.  We've defined
+# a library in cmake called torch, and we might think that we just
+# target_link_libraries(my-app PUBLIC torch).  This will result in a
+# linker argument 'libtorch.a' getting passed to the linker.
+#
+# Unfortunately, this link command is wrong!  We have static
+# initializers in libtorch.a that would get improperly pruned by
+# the default link settings.  What we actually need is for you
+# to do -Wl,--whole-archive,libtorch.a -Wl,--no-whole-archive to ensure
+# that we keep all symbols, even if they are (seemingly) not used.
+#
+# What caffe2_interface_library does is create an interface library
+# that indirectly depends on the real library, but sets up the link
+# arguments so that you get all of the extra link settings you need.
+# The result is not a "real" library, and so we have to manually
+# copy over necessary properties from the original target.
+#
+# (The discussion above is about static libraries, but a similar
+# situation occurs for dynamic libraries: if no symbols are used from
+# a dynamic library, it will be pruned unless you are --no-as-needed)
+macro(caffe2_interface_library SRC DST)
+  add_library(${DST} INTERFACE)
+  add_dependencies(${DST} ${SRC})
+  # Depending on the nature of the source library as well as the compiler,
+  # determine the needed compilation flags.
+  get_target_property(__src_target_type ${SRC} TYPE)
+  # Depending on the type of the source library, we will set up the
+  # link command for the specific SRC library.
+  if(${__src_target_type} STREQUAL "STATIC_LIBRARY")
+    # In the case of static library, we will need to add whole-static flags.
+    target_link_libraries(${DST} INTERFACE $<LINK_LIBRARY:WHOLE_ARCHIVE,${SRC}>)  # the LINK_LIBRARY generator expression needs CMake >= 3.24
+    # Link all interface link libraries of the src target as well.
+    # For static library, we need to explicitly depend on all the libraries
+    # that are the dependent library of the source library. Note that we cannot
+    # use the populated INTERFACE_LINK_LIBRARIES property, because if one of the
+    # dependent library is not a target, cmake creates a $<LINK_ONLY:src> wrapper
+    # and then one is not able to find target "src". For more discussions, check
+    #   https://cmake.org/Bug/print_bug_page.php?bug_id=15415
+    #   https://cmake.org/pipermail/cmake-developers/2013-May/019019.html
+    # Specifically the following quote
+    #
+    # """
+    # For STATIC libraries we can define that the PUBLIC/PRIVATE/INTERFACE keys
+    # are ignored for linking and that it always populates both LINK_LIBRARIES
+    # LINK_INTERFACE_LIBRARIES.  Note that for STATIC libraries the
+    # LINK_LIBRARIES property will not be used for anything except build-order
+    # dependencies.
+    # """
+    target_link_libraries(${DST} INTERFACE
+        $<TARGET_PROPERTY:${SRC},LINK_LIBRARIES>)
+  elseif(${__src_target_type} STREQUAL "SHARED_LIBRARY")
+    if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU")
+      target_link_libraries(${DST} INTERFACE
+          "-Wl,--no-as-needed,\"$<TARGET_FILE:${SRC}>\" -Wl,--as-needed")
+    else()
+      target_link_libraries(${DST} INTERFACE ${SRC})
+    endif()
+    # Link all interface link libraries of the src target as well.
+    # For shared libraries, we can simply depend on the INTERFACE_LINK_LIBRARIES
+    # property of the target.
+    target_link_libraries(${DST} INTERFACE
+        $<TARGET_PROPERTY:${SRC},INTERFACE_LINK_LIBRARIES>)
+  else()
+    message(FATAL_ERROR
+        "You made a CMake build file error: target " ${SRC}
+        " must be of type either STATIC_LIBRARY or SHARED_LIBRARY. However, "
+        "I got " ${__src_target_type} ".")
+  endif()
+  # For all other interface properties, manually inherit from the source target.
+  set_target_properties(${DST} PROPERTIES
+    INTERFACE_COMPILE_DEFINITIONS
+    $<TARGET_PROPERTY:${SRC},INTERFACE_COMPILE_DEFINITIONS>
+    INTERFACE_COMPILE_OPTIONS
+    $<TARGET_PROPERTY:${SRC},INTERFACE_COMPILE_OPTIONS>
+    INTERFACE_INCLUDE_DIRECTORIES
+    $<TARGET_PROPERTY:${SRC},INTERFACE_INCLUDE_DIRECTORIES>
+    INTERFACE_SYSTEM_INCLUDE_DIRECTORIES
+    $<TARGET_PROPERTY:${SRC},INTERFACE_SYSTEM_INCLUDE_DIRECTORIES>)
+endmacro()
+
+
+##############################################################################
+# Creating a Caffe2 binary target with sources specified with relative path.
+# Usage:
+#   caffe2_binary_target(target_name_or_src  [] [] ...)
+# If only target_name_or_src is specified, this target is build with one single
+# source file and the target name is autogen from the filename. Otherwise, the
+# target name is given by the first argument and the rest are the source files
+# to build the target.
+function(caffe2_binary_target target_name_or_src)
+  # ARGC counts every argument actually passed, so a value of 1 means the
+  # caller supplied nothing beyond target_name_or_src.
+  # See https://cmake.org/cmake/help/latest/command/function.html
+  if(ARGC EQUAL 1)
+    # Single-argument form: the argument is the source file and the target
+    # is named after it (NAME_WE: no directory, no extension).
+    get_filename_component(__target ${target_name_or_src} NAME_WE)
+    prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${target_name_or_src}")
+  else()
+    # Multi-argument form: the first argument is the target name and the
+    # remaining arguments (ARGN) are its source files.
+    set(__target ${target_name_or_src})
+    prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}")
+  endif()
+
+  add_executable(${__target} ${__srcs})
+  target_link_libraries(${__target} torch_library)
+  # Additionally link the Caffe2 modules when Caffe2_MODULES is defined.
+  if(DEFINED Caffe2_MODULES)
+    target_link_libraries(${__target} ${Caffe2_MODULES})
+  endif()
+  install(TARGETS ${__target} DESTINATION bin)
+endfunction()
+
+function(caffe2_hip_binary_target target_name_or_src)
+  # Creates a binary through caffe2_binary_target and then adds the HIP
+  # compile flags and include directories to it.
+  #
+  # The target name is derived with the same rule caffe2_binary_target uses,
+  # so the HIP-specific properties below land on the executable it creates.
+  if(ARGC GREATER 1)
+    set(__target ${target_name_or_src})
+  else()
+    get_filename_component(__target ${target_name_or_src} NAME_WE)
+  endif()
+
+  # Forward the extra source arguments as well. Passing only the first
+  # argument would make caffe2_binary_target take its single-argument branch,
+  # treat the target name as the source file and drop the real sources.
+  caffe2_binary_target(${target_name_or_src} ${ARGN})
+
+  target_compile_options(${__target} PRIVATE ${HIP_CXX_FLAGS})
+  target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDE})
+endfunction()
+
+
+##############################################################################
+# Multiplex between adding libraries for CUDA versus HIP (AMD Software Stack).
+# Usage:
+#   torch_cuda_based_add_library(cuda_target)
+#
+macro(torch_cuda_based_add_library cuda_target)
+  # USE_ROCM is tested first, then USE_CUDA. When neither is set the macro
+  # does nothing and no library target is created.
+  if(USE_ROCM)
+    hip_add_library(${cuda_target} ${ARGN})
+  elseif(USE_CUDA)
+    add_library(${cuda_target} ${ARGN})
+  endif()
+endmacro()
+
+##############################################################################
+# Get the HIP arch flags specified by PYTORCH_ROCM_ARCH.
+# Usage:
+#   torch_hip_get_arch_list(variable_to_store_flags)
+#
+macro(torch_hip_get_arch_list store_var)
+  # An explicit PYTORCH_ROCM_ARCH environment variable takes priority over
+  # auto-detection.
+  if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+    set(_TMP $ENV{PYTORCH_ROCM_ARCH})
+  else()
+    # Use arch of installed GPUs as default
+    # The two COMMAND clauses run as a pipeline: the output of
+    # rocm_agent_enumerator is filtered to drop gfx000 entries, de-duplicated
+    # and joined onto one space-separated line.
+    execute_process(COMMAND "rocm_agent_enumerator" COMMAND bash "-c" "grep -v gfx000 | sort -u | xargs | tr -d '\n'"
+                    RESULT_VARIABLE ROCM_AGENT_ENUMERATOR_RESULT
+                    OUTPUT_VARIABLE ROCM_ARCH_INSTALLED)
+    # A non-zero pipeline result aborts configuration.
+    if(NOT ROCM_AGENT_ENUMERATOR_RESULT EQUAL 0)
+      message(FATAL_ERROR " Could not detect ROCm arch for GPUs on machine. Result: '${ROCM_AGENT_ENUMERATOR_RESULT}'")
+    endif()
+    set(_TMP ${ROCM_ARCH_INSTALLED})
+  endif()
+  # Turn the space-separated architectures into a CMake (semicolon) list.
+  # NOTE: this is a macro, so _TMP and the ROCM_* variables set above stay
+  # visible in the caller's scope.
+  string(REPLACE " " ";" ${store_var} "${_TMP}")
+endmacro()
+
+##############################################################################
+# Get the XPU arch flags specified by TORCH_XPU_ARCH_LIST.
+# Usage:
+#   torch_xpu_get_arch_list(variable_to_store_flags)
+#
+macro(torch_xpu_get_arch_list store_var)
+  # Copy the TORCH_XPU_ARCH_LIST environment variable into ${store_var}.
+  # When the environment variable is not defined, ${store_var} is left
+  # untouched (it is neither cleared nor given a default).
+  if(DEFINED ENV{TORCH_XPU_ARCH_LIST})
+    set(${store_var} $ENV{TORCH_XPU_ARCH_LIST})
+  endif()
+endmacro()
+
+##############################################################################
+# Get the NVCC arch flags specified by TORCH_CUDA_ARCH_LIST and CUDA_ARCH_NAME.
+# Usage:
+#   torch_cuda_get_nvcc_gencode_flag(variable_to_store_flags)
+#
+macro(torch_cuda_get_nvcc_gencode_flag store_var)
+  # setting nvcc arch flags
+  # We need to support the explicitly and conveniently defined TORCH_CUDA_ARCH_LIST
+  # A CMake variable of that name wins; the environment variable is only
+  # consulted when the CMake variable is not defined.
+  if((NOT DEFINED TORCH_CUDA_ARCH_LIST) AND (DEFINED ENV{TORCH_CUDA_ARCH_LIST}))
+    set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST})
+  endif()
+  # CUDA_ARCH_NAME is deprecated: warn, then fold its value into
+  # TORCH_CUDA_ARCH_LIST (used as the whole list when the list evaluates to
+  # false, appended to it otherwise).
+  if(DEFINED CUDA_ARCH_NAME)
+    message(WARNING
+        "CUDA_ARCH_NAME is no longer used. Use TORCH_CUDA_ARCH_LIST instead. "
+        "Right now, CUDA_ARCH_NAME is ${CUDA_ARCH_NAME} and "
+        "TORCH_CUDA_ARCH_LIST is ${TORCH_CUDA_ARCH_LIST}.")
+    if(NOT TORCH_CUDA_ARCH_LIST)
+      set(TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME})
+    else()
+      list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME})
+    endif()
+  endif()
+
+  # Invoke cuda_select_nvcc_arch_flags from proper cmake FindCUDA.
+  # NOTE: this is a macro, so any change made to TORCH_CUDA_ARCH_LIST above
+  # remains in effect in the caller's scope.
+  cuda_select_nvcc_arch_flags(${store_var} ${TORCH_CUDA_ARCH_LIST})
+endmacro()
+
+
+##############################################################################
+# Add standard compile options.
+# Usage:
+#   torch_compile_options(lib_name)
+function(torch_compile_options libname)
+  # Applies the standard compile settings to target `libname`: C++17, the
+  # MSVC or GCC/Clang option set, the same warnings forwarded to the CUDA
+  # host compiler, and hidden symbol visibility (non-Windows, non-ASAN).
+  #
+  # Options are wrapped in $<$<COMPILE_LANGUAGE:...>:...> generator
+  # expressions so that they only reach sources of the named language.
+  set_property(TARGET ${libname} PROPERTY CXX_STANDARD 17)
+
+  # until they can be unified, keep these lists synced with setup.py
+  if(MSVC)
+
+    if(MSVC_Z7_OVERRIDE)
+      set(MSVC_DEBINFO_OPTION "/Z7")
+    else()
+      set(MSVC_DEBINFO_OPTION "/Zi")
+    endif()
+
+    if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 142)
+      # Add /permissive- flag for conformance mode to the compiler.
+      # This will force more strict check to the code standard.
+      # 1. From MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/permissive-standards-conformance?view=msvc-170#remarks
+      #    By default, the /permissive- option is set in new projects created by Visual Studio 2017 version 15.5 and later versions.
+      #    We set the /permissive- flag from VS 2019 (MSVC_TOOLSET_VERSION 142) to avoid compiling issues for old toolkit.
+      # 2. For MSVC VERSION: https://cmake.org/cmake/help/latest/variable/MSVC_TOOLSET_VERSION.html
+      target_compile_options(${libname} PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/permissive->)
+    endif()
+    # This option enables a token-based preprocessor that conforms to C99 and C++11 and later standards.
+    # This option is available since VS 2017.
+    # For MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/zc-preprocessor
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor" PARENT_SCOPE)
+
+    # The debug-information option is only added for the Debug and
+    # RelWithDebInfo configurations.
+    target_compile_options(${libname} PUBLIC
+      $<$<COMPILE_LANGUAGE:CXX>:
+        ${MSVC_RUNTIME_LIBRARY_OPTION}
+        $<$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>>:${MSVC_DEBINFO_OPTION}>
+        /EHsc
+        /bigobj>
+      )
+  else()
+    set(private_compile_options
+      -Wall
+      -Wextra
+      -Wdeprecated
+      -Wunused
+      -Wno-unused-parameter
+      -Wno-missing-field-initializers
+      -Wno-array-bounds
+      -Wno-unknown-pragmas
+      -Wno-strict-overflow
+      -Wno-strict-aliasing
+      )
+    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+      list(APPEND private_compile_options -Wredundant-move)
+      # -Wno-interference-size only exists in GCC 12+
+      if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12)
+        list(APPEND private_compile_options -Wno-interference-size)
+      endif()
+    endif()
+    if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+      list(APPEND private_compile_options -Wextra-semi -Wmove)
+    else()
+      list(APPEND private_compile_options
+        # Considered to be flaky.  See the discussion at
+        # https://github.com/pytorch/pytorch/pull/9608
+        -Wno-maybe-uninitialized)
+    endif()
+
+    if(WERROR)
+      list(APPEND private_compile_options
+        -Werror
+        -Werror=ignored-attributes
+        -Werror=inconsistent-missing-override
+        -Werror=inconsistent-missing-destructor-override
+        -Werror=pedantic
+        -Werror=unused
+        -Wno-error=unused-parameter
+      )
+      if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+        list(APPEND private_compile_options -Werror=unused-but-set-variable)
+      endif()
+    endif()
+  endif()
+
+
+  # private_compile_options is only populated in the non-MSVC branch above;
+  # under MSVC this expands to an empty option list.
+  target_compile_options(${libname} PRIVATE
+      $<$<COMPILE_LANGUAGE:CXX>:${private_compile_options}>)
+  if(USE_CUDA)
+    # Forward each option to the CUDA host compiler through -Xcompiler,
+    # skipping the ones filtered out below when that host compiler is GNU.
+    foreach(option IN LISTS private_compile_options)
+      if(CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "GNU")
+        if("${option}" STREQUAL "-Wextra-semi")
+          continue()
+        endif()
+        if("${option}" STREQUAL "-Wunused-private-field")
+          continue()
+        endif()
+      endif()
+      target_compile_options(${libname} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler ${option}>)
+    endforeach()
+  endif()
+
+  if(NOT WIN32 AND NOT USE_ASAN)
+    # Enable hidden visibility by default to make it easier to debug issues with
+    # TORCH_API annotations. Hidden visibility with selective default visibility
+    # behaves close enough to Windows' dllimport/dllexport.
+    #
+    # Unfortunately, hidden visibility messes up some ubsan warnings because
+    # templated classes crossing library boundary get duplicated (but identical)
+    # definitions. It's easier to just disable it.
+    target_compile_options(${libname} PRIVATE
+        $<$<COMPILE_LANGUAGE:CXX>: -fvisibility=hidden>)
+  endif()
+
+endfunction()
+
+##############################################################################
+# Set old-style FindCuda.cmake compile flags from modern CMake cuda flags.
+# Usage:
+#   torch_update_find_cuda_flags()
+function(torch_update_find_cuda_flags)
+  # Convert -O2 -Xcompiler="-O2 -Wall" to "-O2;-Xcompiler=-O2,-Wall"
+  #
+  # For CMAKE_CUDA_FLAGS and each per-configuration variant the same two
+  # steps are applied: separate_arguments() splits the flag string into a
+  # list using UNIX shell quoting rules, then any space that survived inside
+  # a single argument is replaced by a comma. The result is written to the
+  # matching CUDA_NVCC_FLAGS* variable in the caller's scope. Nothing happens
+  # unless USE_CUDA is set.
+  if(USE_CUDA)
+    separate_arguments(FLAGS UNIX_COMMAND "${CMAKE_CUDA_FLAGS}")
+    string(REPLACE " " "," FLAGS "${FLAGS}")
+    # NOTE(review): ${FLAGS} is expanded unquoted here, unlike the
+    # per-configuration variables below - confirm whether that is intentional.
+    set(CUDA_NVCC_FLAGS ${FLAGS} PARENT_SCOPE)
+
+    separate_arguments(FLAGS_DEBUG UNIX_COMMAND "${CMAKE_CUDA_FLAGS_DEBUG}")
+    string(REPLACE " " "," FLAGS_DEBUG "${FLAGS_DEBUG}")
+    set(CUDA_NVCC_FLAGS_DEBUG "${FLAGS_DEBUG}" PARENT_SCOPE)
+
+    separate_arguments(FLAGS_RELEASE UNIX_COMMAND "${CMAKE_CUDA_FLAGS_RELEASE}")
+    string(REPLACE " " "," FLAGS_RELEASE "${FLAGS_RELEASE}")
+    set(CUDA_NVCC_FLAGS_RELEASE "${FLAGS_RELEASE}" PARENT_SCOPE)
+
+    separate_arguments(FLAGS_MINSIZEREL UNIX_COMMAND "${CMAKE_CUDA_FLAGS_MINSIZEREL}")
+    string(REPLACE " " "," FLAGS_MINSIZEREL "${FLAGS_MINSIZEREL}")
+    set(CUDA_NVCC_FLAGS_MINSIZEREL "${FLAGS_MINSIZEREL}" PARENT_SCOPE)
+
+    separate_arguments(FLAGS_RELWITHDEBINFO UNIX_COMMAND "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
+    string(REPLACE " " "," FLAGS_RELWITHDEBINFO "${FLAGS_RELWITHDEBINFO}")
+    set(CUDA_NVCC_FLAGS_RELWITHDEBINFO "${FLAGS_RELWITHDEBINFO}" PARENT_SCOPE)
+
+    # Log the converted values so the translation is visible in configure output.
+    message(STATUS "Converting CMAKE_CUDA_FLAGS to CUDA_NVCC_FLAGS:\n"
+                    "    CUDA_NVCC_FLAGS                = ${FLAGS}\n"
+                    "    CUDA_NVCC_FLAGS_DEBUG          = ${FLAGS_DEBUG}\n"
+                    "    CUDA_NVCC_FLAGS_RELEASE        = ${FLAGS_RELEASE}\n"
+                    "    CUDA_NVCC_FLAGS_RELWITHDEBINFO = ${FLAGS_RELWITHDEBINFO}\n"
+                    "    CUDA_NVCC_FLAGS_MINSIZEREL     = ${FLAGS_MINSIZEREL}")
+  endif()
+endfunction()
+
+include(CheckCXXCompilerFlag)
+include(CheckCCompilerFlag)
+include(CheckLinkerFlag)
+
+##############################################################################
+# Check if the given flag is supported and append it to the provided outputvar.
+# Also defines the HAS_UPPER_CASE_FLAG_NAME variable.
+# Usage:
+#   append_cxx_flag_if_supported("-Werror" CMAKE_CXX_FLAGS)
+function(append_cxx_flag_if_supported flag outputvar)
+    # Name of the result variable, e.g. "-Werror" -> HAS_WERROR: upper-case
+    # the flag behind a "HAS" prefix and map every '=' and '-' to '_'.
+    string(TOUPPER "HAS${flag}" _result_var)
+    string(REGEX REPLACE "[=-]" "_" _result_var "${_result_var}")
+
+    # GCC silently accepts unknown -Wno-XXX flags, so probe the matching
+    # -WXXX form there; every other compiler is probed with the flag itself.
+    set(_probe_flag "${flag}")
+    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+      string(REGEX REPLACE "Wno-" "W" _probe_flag "${flag}")
+    endif()
+
+    check_cxx_compiler_flag("${_probe_flag}" ${_result_var})
+    if(NOT ${_result_var})
+      return()
+    endif()
+    # Hand the extended flag string back to the caller's scope. The original
+    # flag is appended, not the probed form.
+    set(${outputvar} "${${outputvar}} ${flag}" PARENT_SCOPE)
+endfunction()
+
+function(append_c_flag_if_supported flag outputvar)
+    # C counterpart of append_cxx_flag_if_supported: appends `flag` to the
+    # string variable named by `outputvar` (in the caller's scope) when the C
+    # compiler accepts it. The probe result is stored in a variable named
+    # after the flag, e.g. "-Werror" -> HAS_WERROR.
+    string(TOUPPER "HAS${flag}" _FLAG_NAME)
+    string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")
+
+    # GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX.
+    # The pattern has to include the leading dash: flags are passed as
+    # "-Wno-XXX", which an anchored "^Wno-" can never match.
+    if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
+        string(REGEX REPLACE "^-Wno-" "-W" new_flag "${flag}")
+    else()
+        set(new_flag "${flag}")
+    endif()
+
+    check_c_compiler_flag("${new_flag}" ${_FLAG_NAME})
+    if(${_FLAG_NAME})
+        # Append the flag as requested by the caller, not the probed form.
+        string(APPEND ${outputvar} " ${flag}")
+        set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
+    endif()
+endfunction()
+
+function(target_compile_options_if_supported target flag)
+  # Let append_cxx_flag_if_supported decide: it leaves the scratch variable
+  # empty unless the C++ compiler accepted the flag.
+  set(_accepted_flags "")
+  append_cxx_flag_if_supported("${flag}" _accepted_flags)
+  if("${_accepted_flags}" STREQUAL "")
+    return()
+  endif()
+  # Supported: add the flag as a PRIVATE compile option of the target.
+  target_compile_options(${target} PRIVATE ${flag})
+endfunction()
+
+# Check if a global link option is supported
+function(add_link_options_if_supported flag)
+  # Adds "LINKER:<flag>" as a global link option when the linker accepts it;
+  # otherwise only a warning is emitted.
+  #
+  # check_linker_flag() caches its result under the given variable name and
+  # skips the probe when that variable already exists. Every flag therefore
+  # needs its own result variable: with a shared name, all later calls would
+  # reuse the outcome of the first flag that was ever tested.
+  string(MAKE_C_IDENTIFIER "HAS_LINKER_FLAG_${flag}" _supported_var)
+  check_linker_flag(C "LINKER:${flag}" ${_supported_var})
+  if(${_supported_var})
+    add_link_options("LINKER:${flag}")
+  else()
+    message(WARNING "Attempted to use unsupported link option : ${flag}.")
+  endif()
+endfunction()
+
+function(target_link_options_if_supported tgt flag)
+  # Adds "LINKER:<flag>" as a PRIVATE link option of `tgt` when the linker
+  # accepts it; otherwise only a warning is emitted.
+  #
+  # check_linker_flag() caches its result under the given variable name and
+  # skips the probe when that variable already exists. Every flag therefore
+  # needs its own result variable: with a shared name, all later calls would
+  # reuse the outcome of the first flag that was ever tested.
+  string(MAKE_C_IDENTIFIER "HAS_LINKER_FLAG_${flag}" _supported_var)
+  check_linker_flag(C "LINKER:${flag}" ${_supported_var})
+  if(${_supported_var})
+    target_link_options("${tgt}" PRIVATE "LINKER:${flag}")
+  else()
+    message(WARNING "Attempted to use unsupported link option : ${flag}.")
+  endif()
+endfunction()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..b39e31d0ade8aa52206784ae93f37238a3b7fd11
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake
@@ -0,0 +1,56 @@
+# ---[ xpu
+
+# Poor man's include guard
+if(TARGET torch::xpurt)
+  return()
+endif()
+
+set(XPU_HOST_CXX_FLAGS)
+
+# Find SYCL library.
+find_package(SYCLToolkit REQUIRED)
+if(NOT SYCL_FOUND)
+  set(PYTORCH_FOUND_XPU FALSE)
+  # Exit early to avoid populating XPU_HOST_CXX_FLAGS.
+  return()
+endif()
+set(PYTORCH_FOUND_XPU TRUE)
+
+# SYCL library interface
+add_library(torch::sycl INTERFACE IMPORTED)
+
+set_property(
+    TARGET torch::sycl PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+    ${SYCL_INCLUDE_DIR})
+set_property(
+    TARGET torch::sycl PROPERTY INTERFACE_LINK_LIBRARIES
+    ${SYCL_LIBRARY})
+
+# xpurt
+add_library(torch::xpurt INTERFACE IMPORTED)
+set_property(
+    TARGET torch::xpurt PROPERTY INTERFACE_LINK_LIBRARIES
+    torch::sycl)
+
+# setting xpu arch flags
+torch_xpu_get_arch_list(XPU_ARCH_FLAGS)
+# propagate to torch-xpu-ops
+set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS})
+
+# Ensure USE_XPU is enabled.
+string(APPEND XPU_HOST_CXX_FLAGS " -DUSE_XPU")
+string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}")
+
+if(DEFINED ENV{XPU_ENABLE_KINETO})
+  set(XPU_ENABLE_KINETO TRUE)
+else()
+  set(XPU_ENABLE_KINETO FALSE)
+endif()
+
+if(WIN32)
+  if(${SYCL_COMPILER_VERSION} GREATER_EQUAL 20250101)
+    set(XPU_ENABLE_KINETO TRUE)
+  endif()
+else()
+  set(XPU_ENABLE_KINETO TRUE)
+endif()
\ No newline at end of file
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..b59f8ceca10f56aaad16d71c32979919ea0537c1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake
@@ -0,0 +1,39 @@
+#----------------------------------------------------------------
+# Generated CMake target import file for configuration "Release".
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Import target "tensorpipe_uv" for configuration "Release"
+set_property(TARGET tensorpipe_uv APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(tensorpipe_uv PROPERTIES
+  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "C"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libtensorpipe_uv.a"
+  )
+
+list(APPEND _cmake_import_check_targets tensorpipe_uv )
+list(APPEND _cmake_import_check_files_for_tensorpipe_uv "${_IMPORT_PREFIX}/lib64/libtensorpipe_uv.a" )
+
+# Import target "tensorpipe" for configuration "Release"
+set_property(TARGET tensorpipe APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(tensorpipe PROPERTIES
+  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libtensorpipe.a"
+  )
+
+list(APPEND _cmake_import_check_targets tensorpipe )
+list(APPEND _cmake_import_check_files_for_tensorpipe "${_IMPORT_PREFIX}/lib64/libtensorpipe.a" )
+
+# Import target "tensorpipe_cuda" for configuration "Release"
+set_property(TARGET tensorpipe_cuda APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(tensorpipe_cuda PROPERTIES
+  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libtensorpipe_cuda.a"
+  )
+
+list(APPEND _cmake_import_check_targets tensorpipe_cuda )
+list(APPEND _cmake_import_check_files_for_tensorpipe_cuda "${_IMPORT_PREFIX}/lib64/libtensorpipe_cuda.a" )
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..26ba6741ec29a4a4940154884073da6fc469553d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets.cmake
@@ -0,0 +1,122 @@
+# Generated by CMake
+
+if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
+   message(FATAL_ERROR "CMake >= 2.8.12 required")
+endif()
+if(CMAKE_VERSION VERSION_LESS "2.8.12")
+   message(FATAL_ERROR "CMake >= 2.8.12 required")
+endif()
+cmake_policy(PUSH)
+cmake_policy(VERSION 2.8.12...4.0)
+#----------------------------------------------------------------
+# Generated CMake target import file.
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Protect against multiple inclusion, which would fail when already imported targets are added once more.
+set(_cmake_targets_defined "")
+set(_cmake_targets_not_defined "")
+set(_cmake_expected_targets "")
+foreach(_cmake_expected_target IN ITEMS tensorpipe_uv tensorpipe tensorpipe_cuda)
+  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
+  if(TARGET "${_cmake_expected_target}")
+    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
+  else()
+    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
+  endif()
+endforeach()
+unset(_cmake_expected_target)
+if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
+  unset(_cmake_targets_defined)
+  unset(_cmake_targets_not_defined)
+  unset(_cmake_expected_targets)
+  unset(CMAKE_IMPORT_FILE_VERSION)
+  cmake_policy(POP)
+  return()
+endif()
+if(NOT _cmake_targets_defined STREQUAL "")
+  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
+  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
+  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
+endif()
+unset(_cmake_targets_defined)
+unset(_cmake_targets_not_defined)
+unset(_cmake_expected_targets)
+
+
+# Compute the installation prefix relative to this file.
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+if(_IMPORT_PREFIX STREQUAL "/")
+  set(_IMPORT_PREFIX "")
+endif()
+
+# Create imported target tensorpipe_uv
+add_library(tensorpipe_uv STATIC IMPORTED)
+
+set_target_properties(tensorpipe_uv PROPERTIES
+  INTERFACE_LINK_LIBRARIES "\$;\$;\$;\$"
+)
+
+# Create imported target tensorpipe
+add_library(tensorpipe STATIC IMPORTED)
+
+set_target_properties(tensorpipe PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "\$"
+)
+
+# Create imported target tensorpipe_cuda
+add_library(tensorpipe_cuda STATIC IMPORTED)
+
+set_target_properties(tensorpipe_cuda PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "/usr/local/cuda/include"
+  INTERFACE_LINK_LIBRARIES "tensorpipe;/usr/local/cuda/lib64/libcudart.so"
+)
+
+# Load information for each installed configuration.
+file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/TensorpipeTargets-*.cmake")
+foreach(_cmake_config_file IN LISTS _cmake_config_files)
+  include("${_cmake_config_file}")
+endforeach()
+unset(_cmake_config_file)
+unset(_cmake_config_files)
+
+# Cleanup temporary variables.
+set(_IMPORT_PREFIX)
+
+# Loop over all imported files and verify that they actually exist
+foreach(_cmake_target IN LISTS _cmake_import_check_targets)
+  if(CMAKE_VERSION VERSION_LESS "3.28"
+      OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
+      OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
+    foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
+      if(NOT EXISTS "${_cmake_file}")
+        message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
+   \"${_cmake_file}\"
+but this file does not exist.  Possible reasons include:
+* The file was deleted, renamed, or moved to another location.
+* An install or uninstall procedure did not complete successfully.
+* The installation package was faulty and contained
+   \"${CMAKE_CURRENT_LIST_FILE}\"
+but not all the files it references.
+")
+      endif()
+    endforeach()
+  endif()
+  unset(_cmake_file)
+  unset("_cmake_import_check_files_for_${_cmake_target}")
+endforeach()
+unset(_cmake_target)
+unset(_cmake_import_check_targets)
+
+# This file does not depend on other imported targets which have
+# been exported from the same project but in a separate export set.
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
+cmake_policy(POP)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfig.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfig.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..83dc0fd9eb073ff05285b2a3f7a41d745a123899
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfig.cmake
@@ -0,0 +1,170 @@
+# FindTorch
+# -------
+#
+# Finds the Torch library
+#
+# This will define the following variables:
+#
+#   TORCH_FOUND        -- True if the system has the Torch library
+#   TORCH_INCLUDE_DIRS -- The include directories for torch
+#   TORCH_LIBRARIES    -- Libraries to link against
+#   TORCH_CXX_FLAGS    -- Additional (required) compiler flags
+#
+# and the following imported targets:
+#
+#   torch
+macro(append_torchlib_if_found)
+  # For every library name in ARGN, search for it (additionally under
+  # ${TORCH_INSTALL_PREFIX}/lib) and append the resolved path to
+  # TORCH_LIBRARIES. Names that cannot be resolved only produce a warning.
+  foreach (_arg ${ARGN})
+    find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    if(${_arg}_LIBRARY)
+      list(APPEND TORCH_LIBRARIES ${${_arg}_LIBRARY})
+    else()
+      # Report the requested name: on failure the result variable only holds
+      # "<var>-NOTFOUND", which is not a helpful thing to print.
+      message(WARNING "static library ${_arg} not found.")
+    endif()
+  endforeach()
+endmacro()
+
+macro(append_wholearchive_lib_if_found)
+  # Like append_torchlib_if_found, but each resolved library is wrapped in
+  # the platform-specific linker syntax that forces the whole archive to be
+  # linked: -force_load on Apple, -WHOLEARCHIVE on MSVC, and a
+  # --whole-archive/--no-whole-archive pair otherwise.
+  foreach (_arg ${ARGN})
+    find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    if(${_arg}_LIBRARY)
+      if(APPLE)
+        list(APPEND TORCH_LIBRARIES "-Wl,-force_load,${${_arg}_LIBRARY}")
+      elseif(MSVC)
+        list(APPEND TORCH_LIBRARIES "-WHOLEARCHIVE:${${_arg}_LIBRARY}")
+      else()
+        # Linux
+        list(APPEND TORCH_LIBRARIES "-Wl,--whole-archive ${${_arg}_LIBRARY} -Wl,--no-whole-archive")
+      endif()
+    else()
+      # Report the requested name: on failure the result variable only holds
+      # "<var>-NOTFOUND", which is not a helpful thing to print.
+      message(WARNING "static library ${_arg} not found.")
+    endif()
+  endforeach()
+endmacro()
+
+include(FindPackageHandleStandardArgs)
+
+if(DEFINED ENV{TORCH_INSTALL_PREFIX})
+  set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX})
+else()
+  # Assume we are in <prefix>/share/cmake/Torch/TorchConfig.cmake
+  get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+  get_filename_component(TORCH_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
+endif()
+
+# Include directories.
+if(EXISTS "${TORCH_INSTALL_PREFIX}/include")
+  set(TORCH_INCLUDE_DIRS
+    ${TORCH_INSTALL_PREFIX}/include
+    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
+else()
+  set(TORCH_INCLUDE_DIRS
+    ${TORCH_INSTALL_PREFIX}/include
+    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
+endif()
+
+# Library dependencies.
+if(ON)
+  find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2)
+  set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS})
+  append_torchlib_if_found(c10)
+else()
+  add_library(torch STATIC IMPORTED) # set imported_location at the bottom
+  #library need whole archive
+  append_wholearchive_lib_if_found(torch torch_cpu)
+  if(ON)
+    append_wholearchive_lib_if_found(torch_cuda c10_cuda)
+  endif()
+  if(OFF)
+    append_wholearchive_lib_if_found(torch_xpu c10_xpu)
+  endif()
+
+  # We need manually add dependent libraries when they are not linked into the
+  # shared library.
+  # TODO: this list might be incomplete.
+  append_torchlib_if_found(c10)
+
+  if(ON)
+    append_torchlib_if_found(nnpack)
+  endif()
+
+  if(ON)
+    append_torchlib_if_found(pytorch_qnnpack)
+  endif()
+
+  if(ON)
+    append_torchlib_if_found(XNNPACK)
+    append_torchlib_if_found(microkernels-prod)
+  endif()
+
+  if(OFF)
+    append_torchlib_if_found(kleidiai)
+  endif()
+
+  append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc)
+  append_torchlib_if_found(onnx onnx_proto)
+
+  append_torchlib_if_found(fmt)
+  append_torchlib_if_found(cpuinfo clog)
+
+  append_torchlib_if_found(eigen_blas)
+  append_torchlib_if_found(pthreadpool)
+
+  if(ON)
+    append_torchlib_if_found(fbgemm)
+  endif()
+
+  if(ON)
+    append_torchlib_if_found(dnnl mkldnn)
+  endif()
+
+  append_torchlib_if_found(sleef asmjit)
+endif()
+
+if(1)
+  append_torchlib_if_found(kineto)
+endif()
+
+if(ON)
+  if(MSVC)
+    find_library(CAFFE2_NVRTC_LIBRARY caffe2_nvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    list(APPEND TORCH_CUDA_LIBRARIES ${CAFFE2_NVRTC_LIBRARY})
+  else()
+    set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
+  endif()
+  if(TARGET torch::nvtoolsext)
+    list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
+  endif()
+
+  if(ON)
+    find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY} ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
+  endif()
+  list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES})
+endif()
+
+if(OFF AND ON)
+    append_torchlib_if_found(c10_xpu torch_xpu)
+endif()
+
+find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
+# the statements below changes target properties on
+# - the imported target from Caffe2Targets.cmake in shared library mode (see the find_package above)
+#    - this is untested whether it is the correct (or desired) methodology in CMake
+# - the imported target created in this file in static library mode
+if(NOT ON)
+  # do not set this property on the shared library target, as it will cause confusion in some builds
+  # as the configuration specific property is set in the Caffe2Targets.cmake file
+  set_target_properties(torch PROPERTIES
+      IMPORTED_LOCATION "${TORCH_LIBRARY}"
+  )
+endif()
+set_target_properties(torch PROPERTIES
+    INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
+    CXX_STANDARD 17
+)
+if(TORCH_CXX_FLAGS)
+  set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}")
+endif()
+
+find_package_handle_standard_args(Torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..c7379319b36ec11b13d940841cde5ff9d17025ce
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake
@@ -0,0 +1,11 @@
+set(PACKAGE_VERSION "2.10.0")
+
+# Check whether the requested PACKAGE_FIND_VERSION is compatible
+if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
+  set(PACKAGE_VERSION_COMPATIBLE FALSE)
+else()
+  set(PACKAGE_VERSION_COMPATIBLE TRUE)
+  if("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
+    set(PACKAGE_VERSION_EXACT TRUE)
+  endif()
+endif()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d73ec2e70364336fc8c7396d99b9d92ff8ec956
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6749a92c6fc1525ea95c7d4d1e398229ab10b7a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__init__.py
@@ -0,0 +1,28 @@
+from .windows import (
+    bartlett,
+    blackman,
+    cosine,
+    exponential,
+    gaussian,
+    general_cosine,
+    general_hamming,
+    hamming,
+    hann,
+    kaiser,
+    nuttall,
+)
+
+
+__all__ = [
+    "bartlett",
+    "blackman",
+    "cosine",
+    "exponential",
+    "gaussian",
+    "general_cosine",
+    "general_hamming",
+    "hamming",
+    "hann",
+    "kaiser",
+    "nuttall",
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b500b6594fadd1d84e6177c8a6cec72eb1069e0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd82a2811aa69f97a645f33c4ba7492b2937f007
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/windows.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/windows.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda60aadfe1d6208354b045a86700e858cc946f0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/signal/windows/windows.py
@@ -0,0 +1,883 @@
+# mypy: allow-untyped-defs
+from collections.abc import Callable, Iterable
+from math import sqrt
+from typing import TypeVar
+
+import torch
+from torch import Tensor
+from torch._torch_docs import factory_common_args, merge_dicts, parse_kwargs
+
+
+__all__ = [
+    "bartlett",
+    "blackman",
+    "cosine",
+    "exponential",
+    "gaussian",
+    "general_cosine",
+    "general_hamming",
+    "hamming",
+    "hann",
+    "kaiser",
+    "nuttall",
+]
+
+_T = TypeVar("_T")  # type of the object decorated by _add_docstr (returned unchanged)
+
+window_common_args = merge_dicts(  # docstring fragments substituted via .format(**window_common_args) below
+    parse_kwargs(
+        """
+    M (int): the length of the window.
+        In other words, the number of points of the returned window.
+    sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
+        If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
+"""
+    ),
+    factory_common_args,
+    {
+        "normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if "
+        ":attr:`M` is even and :attr:`sym` is `True`.",
+    },
+)
+
+
+def _add_docstr(*args: str) -> Callable[[_T], _T]:
+    r"""Adds docstrings to a given decorated function.
+
+    Especially useful when the docstring needs string interpolation, e.g., with
+    str.format().
+    REMARK: Do not use this function if the docstring doesn't need string
+    interpolation, just write a conventional docstring.
+
+    Args:
+        args (str): docstring fragments; they are concatenated, in order, into ``__doc__``.
+    """
+
+    def decorator(o: _T) -> _T:
+        o.__doc__ = "".join(args)  # overwrites any docstring the object already had
+        return o  # the same object is returned, only its __doc__ changed
+
+    return decorator
+
+
+def _window_function_checks(
+    function_name: str, M: int, dtype: torch.dtype, layout: torch.layout
+) -> None:
+    r"""Performs common checks for all the defined windows.
+    Call it before computing any window; every failed check raises ``ValueError``.
+
+    Args:
+        function_name (str): name of the window function (used only in the error messages).
+        M (int): length of the window.
+        dtype (:class:`torch.dtype`): the desired data type of returned tensor.
+        layout (:class:`torch.layout`): the desired layout of returned tensor.
+    """
+    if M < 0:  # M == 0 passes this check
+        raise ValueError(
+            f"{function_name} requires non-negative window length, got M={M}"
+        )
+    if layout is not torch.strided:  # identity comparison against the strided layout object
+        raise ValueError(
+            f"{function_name} is implemented for strided tensors only, got: {layout}"
+        )
+    if dtype not in [torch.float32, torch.float64]:  # any other dtype is rejected
+        raise ValueError(
+            f"{function_name} expects float32 or float64 dtypes, got: {dtype}"
+        )
+
+
+@_add_docstr(
+    r"""
+Computes a window with an exponential waveform.
+Also known as Poisson window.
+
+The exponential window is defined as follows:
+
+.. math::
+    w_n = \exp{\left(-\frac{|n - c|}{\tau}\right)}
+
+where `c` is the ``center`` of the window.
+    """,
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    center (float, optional): where the center of the window will be located.
+        Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
+    tau (float, optional): the decay value.
+        Tau is generally associated with a percentage, that means, that the value should
+        vary within the interval (0, 100]. If tau is 100, it is considered the uniform window.
+        Default: 1.0.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric exponential window of size 10 and with a decay value of 1.0.
+    >>> # The center will be at (M - 1) / 2, where M is 10.
+    >>> torch.signal.windows.exponential(10)
+    tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111])
+
+    >>> # Generates a periodic exponential window and decay factor equal to .5
+    >>> torch.signal.windows.exponential(10, sym=False,tau=.5)
+    tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
+    """.format(**window_common_args),
+)
+def exponential(
+    M: int,
+    *,
+    center: float | None = None,
+    tau: float = 1.0,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("exponential", M, dtype, layout)
+
+    if tau <= 0:
+        raise ValueError(f"Tau must be positive, got: {tau} instead.")
+
+    if sym and center is not None:  # an explicit center is only accepted with sym=False
+        raise ValueError("Center must be None for symmetric windows")
+
+    if M == 0:  # empty window: nothing to compute
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if center is None:
+        center = (M if not sym and M > 1 else M - 1) / 2.0  # M / 2 when periodic and M > 1, else (M - 1) / 2
+
+    constant = 1 / tau  # scale factor so that k holds (n - center) / tau
+
+    k = torch.linspace(
+        start=-center * constant,
+        end=(-center + (M - 1)) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.exp(-torch.abs(k))  # w_n = exp(-|n - center| / tau)
+
+
+@_add_docstr(
+    r"""
+Computes a window with a simple cosine waveform, following the same implementation as SciPy.
+This window is also known as the sine window.
+
+The cosine window is defined as follows:
+
+.. math::
+    w_n = \sin\left(\frac{\pi (n + 0.5)}{M}\right)
+
+This formula differs from the typical cosine window formula by incorporating a 0.5 term in the numerator,
+which shifts the sample positions. This adjustment results in a window that starts and ends with non-zero values.
+
+""",
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric cosine window.
+    >>> torch.signal.windows.cosine(10)
+    tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564])
+
+    >>> # Generates a periodic cosine window.
+    >>> torch.signal.windows.cosine(10, sym=False)
+    tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154])
+""".format(
+        **window_common_args,
+    ),
+)
+def cosine(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("cosine", M, dtype, layout)
+
+    if M == 0:  # empty window: nothing to compute
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    start = 0.5  # half-sample offset: k_n = (n + 0.5) * constant
+    constant = torch.pi / (M + 1 if not sym and M > 1 else M)  # divisor is M + 1 when periodic and M > 1, else M
+
+    k = torch.linspace(
+        start=start * constant,
+        end=(start + (M - 1)) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.sin(k)  # w_n = sin(k_n)
+
+
+@_add_docstr(
+    r"""
+Computes a window with a gaussian waveform.
+
+The gaussian window is defined as follows:
+
+.. math::
+    w_n = \exp{\left(-\left(\frac{n}{\sqrt{2}\sigma}\right)^2\right)}
+    """,
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
+        Default: 1.0.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric gaussian window with a standard deviation of 1.0.
+    >>> torch.signal.windows.gaussian(10)
+    tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
+
+    >>> # Generates a periodic gaussian window and standard deviation equal to 0.9.
+    >>> torch.signal.windows.gaussian(10, sym=False,std=0.9)
+    tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
+""".format(
+        **window_common_args,
+    ),
+)
+def gaussian(
+    M: int,
+    *,
+    std: float = 1.0,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("gaussian", M, dtype, layout)
+
+    if std <= 0:
+        raise ValueError(f"Standard deviation must be positive, got: {std} instead.")
+
+    if M == 0:  # empty window: nothing to compute
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    start = -(M if not sym and M > 1 else M - 1) / 2.0  # offset of sample 0 from the window center
+
+    constant = 1 / (std * sqrt(2))  # so that k**2 equals offset**2 / (2 * std**2)
+
+    k = torch.linspace(
+        start=start * constant,
+        end=(start + (M - 1)) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.exp(-(k**2))  # pyrefly: ignore [unsupported-operation]
+
+
+@_add_docstr(
+    r"""
+Computes the Kaiser window.
+
+The Kaiser window is defined as follows:
+
+.. math::
+    w_n = I_0 \left( \beta \sqrt{1 - \left( {\frac{n - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta )
+
+where ``I_0`` is the zeroth order modified Bessel function of the first kind (see :func:`torch.special.i0`), and
+``N = M - 1 if sym else M``.
+    """,
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    beta (float, optional): shape parameter for the window. Must be non-negative. Default: 12.0
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Kaiser window with the default beta of 12.0.
+    >>> torch.signal.windows.kaiser(5)
+    tensor([5.2773e-05, 2.1567e-01, 1.0000e+00, 2.1567e-01, 5.2773e-05])
+    >>> # Generates a periodic Kaiser window with the default beta of 12.0.
+    >>> torch.signal.windows.kaiser(5, sym=False)
+    tensor([5.2773e-05, 1.0172e-01, 7.9294e-01, 7.9294e-01, 1.0172e-01])
+""".format(
+        **window_common_args,
+    ),
+)
+def kaiser(
+    M: int,
+    *,
+    beta: float = 12.0,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("kaiser", M, dtype, layout)
+
+    if beta < 0:
+        raise ValueError(f"beta must be non-negative, got: {beta} instead.")
+
+    if M == 0:  # empty window: nothing to compute
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if M == 1:  # handled up front: with sym=True the divisor M - 1 used below would be 0
+        return torch.ones(
+            (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    # Avoid NaNs by casting `beta` to the appropriate dtype.
+    # pyrefly: ignore [bad-assignment]
+    beta = torch.tensor(beta, dtype=dtype, device=device)
+
+    start = -beta  # k starts at -beta and increases in steps of `constant`
+    constant = 2.0 * beta / (M if not sym else M - 1)  # divisor is M - 1 when symmetric, M when periodic
+    end = torch.minimum(  # never let the end point exceed beta (presumably a guard against rounding)
+        # pyrefly: ignore [bad-argument-type]
+        beta,
+        # pyrefly: ignore [bad-argument-type]
+        start + (M - 1) * constant,
+    )
+
+    k = torch.linspace(
+        start=start,
+        end=end,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(  # I0(sqrt(beta^2 - k^2)) / I0(beta)
+        # pyrefly: ignore [bad-argument-type]
+        beta
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Hamming window.
+
+The Hamming window is defined as follows:
+
+.. math::
+    w_n = \alpha - \beta\ \cos \left( \frac{2 \pi n}{M - 1} \right)
+
+where :math:`\alpha = 0.54` and :math:`\beta = 1 - \alpha = 0.46` are fixed; see :func:`general_hamming` for other values.
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Hamming window.
+    >>> torch.signal.windows.hamming(10)
+    tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800])
+
+    >>> # Generates a periodic Hamming window.
+    >>> torch.signal.windows.hamming(10, sym=False)
+    tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679])
+""".format(**window_common_args),
+)
+def hamming(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_hamming(  # alpha is not passed, so general_hamming's default alpha applies
+        M,
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Hann window.
+
+The Hann window is defined as follows:
+
+.. math::
+    w_n = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{M - 1} \right)\right] =
+    \sin^2 \left( \frac{\pi n}{M - 1} \right)
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Hann window.
+    >>> torch.signal.windows.hann(10)
+    tensor([0.0000, 0.1170, 0.4132, 0.7500, 0.9698, 0.9698, 0.7500, 0.4132, 0.1170, 0.0000])
+
+    >>> # Generates a periodic Hann window.
+    >>> torch.signal.windows.hann(10, sym=False)
+    tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
+""".format(**window_common_args),
+)
+def hann(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_hamming(  # all validation and computation is delegated
+        M,
+        alpha=0.5,  # the Hann window is the general Hamming window with alpha = 0.5
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Blackman window.
+
+The Blackman window is defined as follows:
+
+.. math::
+    w_n = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{M - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{M - 1} \right)
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Blackman window.
+    >>> torch.signal.windows.blackman(5)
+    tensor([-1.4901e-08,  3.4000e-01,  1.0000e+00,  3.4000e-01, -1.4901e-08])
+
+    >>> # Generates a periodic Blackman window.
+    >>> torch.signal.windows.blackman(5, sym=False)
+    tensor([-1.4901e-08,  2.0077e-01,  8.4923e-01,  8.4923e-01,  2.0077e-01])
+""".format(**window_common_args),
+)
+def blackman(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("blackman", M, dtype, layout)  # checked here so error messages name "blackman"
+
+    return general_cosine(
+        M,
+        a=[0.42, 0.5, 0.08],  # magnitudes only; general_cosine alternates the signs (+, -, +)
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Bartlett window.
+
+The Bartlett window is defined as follows:
+
+.. math::
+    w_n = 1 - \left| \frac{2n}{M - 1} - 1 \right| = \begin{cases}
+        \frac{2n}{M - 1} & \text{if } 0 \leq n \leq \frac{M - 1}{2} \\
+        2 - \frac{2n}{M - 1} & \text{if } \frac{M - 1}{2} < n < M \\ \end{cases}
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Bartlett window.
+    >>> torch.signal.windows.bartlett(10)
+    tensor([0.0000, 0.2222, 0.4444, 0.6667, 0.8889, 0.8889, 0.6667, 0.4444, 0.2222, 0.0000])
+
+    >>> # Generates a periodic Bartlett window.
+    >>> torch.signal.windows.bartlett(10, sym=False)
+    tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
+""".format(**window_common_args),
+)
+def bartlett(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("bartlett", M, dtype, layout)
+
+    if M == 0:  # empty window: nothing to compute
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if M == 1:  # handled up front: with sym=True the divisor M - 1 used below would be 0
+        return torch.ones(
+            (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    start = -1  # k starts at -1 and increases in steps of `constant`
+    constant = 2 / (M if not sym else M - 1)  # divisor is M - 1 when symmetric, M when periodic
+
+    k = torch.linspace(
+        start=start,
+        end=start + (M - 1) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return 1 - torch.abs(k)  # triangle: 0 where |k| = 1, 1 where k = 0
+
+
+@_add_docstr(
+    r"""
+Computes the general cosine window.
+
+The general cosine window is defined as follows:
+
+.. math::
+    w_n = \sum^{\mathrm{len}(a) - 1}_{i=0} (-1)^i a_i \cos{ \left( \frac{2 \pi i n}{M - 1}\right)}
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    a (Iterable): the coefficients associated to each of the cosine functions.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric general cosine window with 3 coefficients.
+    >>> torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31], sym=True)
+    tensor([0.5400, 0.3376, 0.1288, 0.4200, 0.9136, 0.9136, 0.4200, 0.1288, 0.3376, 0.5400])
+
+    >>> # Generates a periodic general cosine window with 2 coefficients.
+    >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False)
+    tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
+""".format(**window_common_args),
+)
+def general_cosine(
+    M: int,
+    *,
+    a: Iterable,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:  # fall back to the global default dtype
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("general_cosine", M, dtype, layout)
+
+    if M == 0:  # empty window: nothing to compute
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if M == 1:  # handled up front (sym=True would divide by M - 1 == 0); `a` is not validated on this path
+        return torch.ones(
+            (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if not isinstance(a, Iterable):
+        raise TypeError("Coefficients must be a list/tuple")
+    a = list(a)  # materialise once: makes the emptiness test valid for generators and array-likes
+    if not a:
+        raise ValueError("Coefficients cannot be empty")
+
+    constant = 2 * torch.pi / (M if not sym else M - 1)  # angular step; divisor is M - 1 when symmetric, M when periodic
+
+    k = torch.linspace(
+        start=0,
+        end=(M - 1) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    a_i = torch.tensor(  # coefficients with alternating signs: a_0, -a_1, a_2, ...
+        [(-1) ** i * w for i, w in enumerate(a)],
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    i = torch.arange(  # index 0 .. len(a) - 1 of each cosine term
+        a_i.shape[0],
+        dtype=a_i.dtype,
+        device=a_i.device,
+        requires_grad=a_i.requires_grad,
+    )
+    return (a_i.unsqueeze(-1) * torch.cos(i.unsqueeze(-1) * k)).sum(0)  # (len(a), M) terms summed over the coefficient axis
+
+
+@_add_docstr(
+    r"""
+Computes the general Hamming window.
+
+The general Hamming window is defined as follows:
+
+.. math::
+    w_n = \alpha - (1 - \alpha) \cos{ \left( \frac{2 \pi n}{M-1} \right)}
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    alpha (float, optional): the window coefficient. Default: 0.54.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Hamming window with the general Hamming window.
+    >>> torch.signal.windows.general_hamming(10, sym=True)
+    tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800])
+
+    >>> # Generates a periodic Hann window with the general Hamming window.
+    >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False)
+    tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
+""".format(**window_common_args),
+)
+def general_hamming(
+    M,
+    *,
+    alpha: float = 0.54,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_cosine(  # all validation and computation is delegated
+        M,
+        a=[alpha, 1.0 - alpha],  # general_cosine negates the second term: alpha - (1 - alpha) * cos(...)
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the minimum 4-term Blackman-Harris window according to Nuttall.
+
+.. math::
+    w_n = 0.3635819 - 0.4891775 \cos{(z_n)} + 0.1365995 \cos{(2z_n)} - 0.0106411 \cos{(3z_n)}
+
+where :math:`z_n = \frac{2 \pi n}{N}` and ``N = M - 1 if sym else M``.
+    """,
+    """
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+References::
+
+    - A. Nuttall, "Some windows with very good sidelobe behavior,"
+      IEEE Transactions on Acoustics, Speech, and Signal Processing, vol. 29, no. 1, pp. 84-91,
+      Feb 1981. https://doi.org/10.1109/TASSP.1981.1163506
+
+    - Heinzel G. et al., "Spectrum and spectral density estimation by the Discrete Fourier transform (DFT),
+      including a comprehensive list of window functions and some new flat-top windows",
+      February 15, 2002 https://holometer.fnal.gov/GH_FFT.pdf
+
+Examples::
+
+    >>> # Generates a symmetric Nuttall window.
+    >>> torch.signal.windows.nuttall(5, sym=True)
+    tensor([3.6280e-04, 2.2698e-01, 1.0000e+00, 2.2698e-01, 3.6280e-04])
+
+    >>> # Generates a periodic Nuttall window.
+    >>> torch.signal.windows.nuttall(5, sym=False)
+    tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01])
+""".format(**window_common_args),
+)
+def nuttall(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_cosine(  # all validation and computation is delegated
+        M,
+        a=[0.3635819, 0.4891775, 0.1365995, 0.0106411],  # magnitudes only; general_cosine alternates the signs (+, -, +, -)
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8f3226eda73a55349e2f23d9e69508eebd26596
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4152c4f5c60c00d8bc7dce93d8d431ffb8b4751
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e255c9af7f9b8441333672f766c0935a49e868b9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ef0a65f783911036f28761c2a95f62bbfd7e68c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..accf6f05630a154eee4eda8567b115085e84c88e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3616fede6ce67b70f244419236864a03ffcbb35
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py
@@ -0,0 +1,472 @@
+# mypy: ignore-errors
+
+import collections
+
+import torch
+from torch.testing._internal.common_utils import TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase
+
+
+class AutocastTestLists:  # Per-category lists of (op name, args, ...) entries, built for one device in __init__.
+    def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype):  # NOTE: `dtype` is never read; every tensor below is float32.
+        input = (torch.randn((n, n), device=dev, dtype=torch.float32),)
+
+        hx = ((torch.randn((n, n), device=dev, dtype=torch.float32),  # one-element tuple: a pair of tensors if is_lstm, else one tensor
+               torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else
+              torch.randn((n, n), device=dev, dtype=torch.float32),)
+
+        weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_ih
+                   torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_hh
+                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32),  # bias_ih
+                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32))  # bias_hh
+
+        # Concatenate into one flat argument tuple: input, hidden state, weight_ih, weight_hh, bias_ih, bias_hh.
+        return input + hx + weights
+
+    # Supplies ops and arguments for test_autocast_* in test/test_cuda.py
+    def __init__(self, dev):
+        super().__init__()
+        n = 8
+        # Utility arguments, created as one-element tuples so they can be concatenated with `+`
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+        mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+        mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+
+        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))  # 3-D, 4-D and 5-D shapes; indexed below as conv_args_fp32[0..2]
+        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                           torch.randn(dimset, dtype=torch.float32, device=dev))
+                          for dimset in dimsets]
+        bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
+        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+        # The lists below organize ops that autocast needs to test.
+        # self.list_name corresponds to test_autocast_list_name in test/test_cuda.py.
+        # Each op is associated with a tuple of valid arguments.
+        # In addition, cudnn conv ops are not supported on ROCm and hence will
+        # be skipped by passing TEST_WITH_ROCM flag to those ops in self.torch_fp16 list.
+
+        # Some ops implement built-in type promotion.  These don't need autocasting,
+        # but autocasting relies on their promotion, so we include tests to double-check.
+        self.torch_expect_builtin_promote = [
+            ("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
+            ("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
+        ]
+        self.methods_expect_builtin_promote = [
+            ("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+
+        # The remaining lists organize ops that autocast treats explicitly.
+        self.torch_fp16 = [
+            # deprecated _convolution
+            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
+                                                              (0, 0), 1, False, True, True)),
+            # the current  _convolution
+            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
+                                                              (0, 0), 1, False, True, True, True)),
+            ("conv1d", conv_args_fp32[0]),
+            ("conv2d", conv_args_fp32[1]),
+            ("conv3d", conv_args_fp32[2]),
+            ("conv_tbc", conv_args_fp32[0] + bias_fp32),
+            ("conv_transpose1d", conv_args_fp32[0]),
+            ("conv_transpose2d", conv_args_fp32[1]),
+            ("conv_transpose3d", conv_args_fp32[2]),
+            ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
+            ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
+            ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
+                                                                 (1, 1), 1, False, True, True), TEST_WITH_ROCM),
+            ("prelu", pointwise0_fp32 + element0_fp32),
+            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+            ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),
+            ("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            ("matmul", mat0_fp32 + mat1_fp32),
+            ("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32),  # NOTE(review): equation string sits where other entries hold the args tuple — confirm the consumer handles this
+            ("mm", mat0_fp32 + mat1_fp32),
+            ("mv", mat0_fp32 + pointwise0_fp32),
+            ("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
+            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            # _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell.
+            # ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            # ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            ("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)),
+            ("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)),
+            ("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
+            ("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
+        ]
+        self.torch_fp32 = [
+            ("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
+            ("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
+            ("cosh", pointwise0_fp16),
+            ("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)),
+            ("exp", pointwise0_fp16),
+            ("expm1", pointwise0_fp16),
+            ("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)),
+            ("reciprocal", pointwise0_fp16),
+            ("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)),
+            ("sinh", pointwise0_fp16),
+            ("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)),
+            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16),
+            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
+            # ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API.
+            ("softmax", pointwise0_fp16 + (0,)),
+            ("log_softmax", pointwise0_fp16 + (0,)),
+            ("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)),
+            ("group_norm", mat0_fp16 + (1,)),
+            ("norm", pointwise0_fp16),
+            ("norm", pointwise0_fp16, {"dim": 0}),
+            # these need magma
+            # ("norm", mat0_fp16, {"p": "nuc"}),
+            # ("norm", mat0_fp16, {"p": "nuc", "dim": 0}),
+            ("norm", pointwise0_fp16, {"p": 1}),
+            ("norm", pointwise0_fp16, {"p": 1, "dim": 0}),
+            ("cosine_similarity", mat0_fp16 + mat1_fp16),
+            ("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
+            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16),
+                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16),
+                                       torch.tensor([1], device=dev, dtype=torch.int))),
+            ("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)),
+            ("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
+            ("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)),
+            ("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16),
+            ("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
+            ("cumprod", pointwise0_fp16 + (0,)),
+            ("cumsum", pointwise0_fp16 + (0,)),
+            ("dist", pointwise0_fp16 + pointwise1_fp16),
+            ("pdist", mat0_fp16),
+            ("cdist", mat0_fp16 + mat1_fp16),
+            ("prod", pointwise0_fp16),
+            ("prod", pointwise0_fp16 + (0,)),
+            ("renorm", mat0_fp16 + (2, 0, 1.0)),
+            ("sum", pointwise0_fp16),
+            ("sum", mat0_fp16 + (1,)),
+            ("logsumexp", mat0_fp16 + (1,)),
+        ]
+        self.torch_need_autocast_promote = [
+            ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)),
+            ("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16),
+            ("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)),
+            ("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev),
+                          torch.randn((1, 2), dtype=torch.float32, device=dev),
+                          torch.randn((1, 2, 2), dtype=torch.float16, device=dev),
+                          torch.randn((1,), dtype=torch.float32, device=dev))),
+            ("cross", (torch.randn(3, dtype=torch.float32, device=dev),
+                       torch.randn(3, dtype=torch.float16, device=dev))),
+            ("dot", pointwise0_fp16 + pointwise1_fp32),
+            ("vdot", pointwise0_fp16 + pointwise1_fp32),
+            ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev),
+                              torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev),
+                              0, 0, False)),
+            ("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),),
+                                             torch.randn(1, device=dev, dtype=torch.float16))),
+            ("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),),
+                                             torch.randn(1, device=dev, dtype=torch.float32))),
+            ("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
+                           torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
+            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
+                             0,
+                             torch.randint(0, 2, (2, 2, 2), device=dev),
+                             torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
+            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
+                             0,
+                             torch.randint(0, 2, (2, 2, 2), device=dev),
+                             torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
+        ]
+        self.nn_fp16 = [
+            ("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
+        ]
+        self.nn_fp32 = [
+            ("softplus", pointwise0_fp16),
+            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
+                          torch.zeros((n,), device=dev, dtype=torch.long))),
+            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half),
+                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
+            ("l1_loss", mat0_fp16 + mat1_fp16),
+            ("smooth_l1_loss", mat0_fp16 + mat1_fp16),
+            ("mse_loss", mat0_fp16 + mat1_fp16),
+            ("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
+        ]
+        self.linalg_fp16 = [
+            ("linalg_vecdot", mat0_fp32 + mat0_fp32),
+            ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
+        ]
+        self.methods_fp16 = [
+            ("__matmul__", mat0_fp32 + mat1_fp32)
+        ]
+        self.methods_fp32 = [
+            ("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)),
+        ]
+        self.banned = [  # entries carry a third element: a module object (presumably where the op is looked up — confirm in the consumer)
+            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
+                                      torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
+        ]
+
+
+class AutocastCPUTestLists:  # Per-category lists of (op name, args, ...) entries using bfloat16/float16/float32 inputs.
+    # Supplies ops and arguments for test_autocast_* in test/test_cpu.py
+    def __init__(self, dev):
+        super().__init__()
+        n = 8
+        # Utility arguments, created as one-element tuples so they can be concatenated with `+`
+        pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
+        pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
+        mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+
+        dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))  # 1-D through 5-D shapes
+
+        dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
+                      for dimset in dummy_dimsets]
+
+        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))  # 3-D, 4-D and 5-D shapes; indexed below as conv_args_fp32[0..2]
+        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                           torch.randn(dimset, dtype=torch.float32, device=dev))
+                          for dimset in dimsets]
+
+        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+        dummy_fp32 = [  # noqa: F841
+            (torch.randn(dimset, dtype=torch.float32, device=dev),)
+            for dimset in dummy_dimsets
+        ]
+        # The lists below organize ops that autocast needs to test.
+        # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
+        # Each op is associated with a tuple of valid arguments.
+
+        # Some ops implement built-in type promotion.  These don't need autocasting,
+        # but autocasting relies on their promotion, so we include tests to double-check.
+        self.torch_expect_builtin_promote = [  # entries: (op, fp32+bf16 args, fp32+fp16 args, dtype — presumably the expected output dtype)
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+
+        self.methods_expect_builtin_promote = [  # same four-element entry layout as torch_expect_builtin_promote
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+        # The remaining lists organize ops that autocast treats explicitly.
+        self.torch_16 = [
+            ("conv1d", conv_args_fp32[0]),
+            ("conv2d", conv_args_fp32[1]),
+            ("conv3d", conv_args_fp32[2]),
+            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("mm", mat0_fp32 + mat1_fp32),
+            ("matmul", mat0_fp32 + mat1_fp32),
+            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+            ("_addmm_activation", mat1_fp32 + mat2_fp32 + mat3_fp32, {"beta": 1, "alpha": 1, "use_gelu": True}),
+            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
+                          torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
+                          torch.randn(5, device=dev, dtype=torch.float32),
+                          0)),
+            ("conv_transpose1d", conv_args_fp32[0]),
+            ("conv_transpose2d", conv_args_fp32[1]),
+            ("conv_transpose3d", conv_args_fp32[2]),
+            ("prelu", pointwise0_fp32 + element0_fp32),
+            ("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((3 * n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n), device=dev, dtype=torch.float32))),
+        ]
+        self.torch_fp32 = [
+            ("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
+            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16),
+                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16),
+                                       torch.tensor([1], device=dev, dtype=torch.int))),
+            ("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
+            ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)),
+            ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
+            ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
+        ]
+        self.nn_16 = [
+            ("linear", mat0_fp32 + mat1_fp32, {}),
+        ]
+        self.nn_fp32 = [
+            ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
+            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
+                                     (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
+            ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
+            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),
+                          torch.zeros((n,), device=dev, dtype=torch.long))),
+            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16),
+                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
+            ("l1_loss", mat0_bf16 + mat1_bf16),
+            ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
+            ("mse_loss", mat0_bf16 + mat1_bf16),
+            ("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
+            ("huber_loss", mat0_bf16 + mat1_bf16),
+        ]
+        self.torch_need_autocast_promote = [  # entries: (op, args mixing bf16 with fp32, args mixing fp16 with fp32)
+            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+        ]
+
+
+class TestAutocast(TestCase):
+    def args_maybe_kwargs(self, op_with_args):
+        if len(op_with_args) == 2:
+            return op_with_args[0], op_with_args[1], {}
+        else:
+            return op_with_args[0], op_with_args[1], op_with_args[2]
+
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        device,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
+        # helper to cast args
+        def cast(val, to_type):
+            if isinstance(val, torch.Tensor):
+                return val.to(to_type) if val.is_floating_point() else val
+            elif isinstance(val, collections.abc.Iterable):
+                return type(val)(cast(v, to_type) for v in val)
+            else:
+                return val
+
+        if add_kwargs is None:
+            add_kwargs = {}
+
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
+        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+
+            out_type = out_type if out_type is not None else run_as_type
+            output = output_method = None
+
+            # Try module.* variant, if requested:
+            if module is not None and hasattr(module, op):
+                output = getattr(module, op)(*args, **add_kwargs)
+                if isinstance(output, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output.dtype,
+                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+                    )
+            # Try Tensor.* variant:
+            if hasattr(torch.Tensor, op):
+                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
+                if isinstance(output_method, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output_method.dtype,
+                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+                    )
+
+            self.assertTrue(
+                (output is not None) or (output_method is not None),
+                f"{op} not found as an attribute on either Tensor or the requested module {module}",
+            )
+
+            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
+            # For example, lstm_cell returns a tuple and equal returns bool.
+            def compare(first, second):
+                if isinstance(first, torch.Tensor):
+                    return torch.equal(first, second)
+                elif isinstance(first, collections.abc.Iterable):
+                    return all(compare(f, s) for f, s in zip(first, second, strict=False))
+                else:
+                    return first == second
+
+            # If both torch.* and Tensor.* variants were found, check outputs are identical
+            if (output is not None) and (output_method is not None):
+                self.assertTrue(type(output) is type(output_method))
+                comparison = compare(output, output_method)
+                self.assertTrue(
+                    comparison, f"torch.{op} result did not match Tensor.{op} result"
+                )
+
+            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
+            # as the C++-side autocasting, and should be bitwise accurate.
+            output_to_compare = output if output is not None else output_method
+            with torch.amp.autocast(device_type=device, enabled=False):
+                self.assertFalse(
+                    torch.is_autocast_enabled(device_type=device)
+                )
+
+                if module is not None and hasattr(module, op):
+                    control = getattr(module, op)(
+                        *cast(args, run_as_type), **add_kwargs
+                    )
+                else:
+                    control = getattr(args[0].to(run_as_type), op)(
+                        *cast(args[1:], run_as_type), **add_kwargs
+                    )
+                self.assertTrue(type(output_to_compare) is type(control))
+                comparison = compare(output_to_compare, control)
+                self.assertTrue(comparison, f"torch.{op} result did not match control")
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2219ef4ea56aa306dfdd3af18b7403af8384c78
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py
@@ -0,0 +1,164 @@
+# mypy: ignore-errors
+
+import os
+import re
+import sys
+
+__all__ = [
+    "check_code_for_cuda_kernel_launches",
+    "check_cuda_kernel_launches",
+]
+
+# FILES TO EXCLUDE (match is done with suffix using `endswith`)
+# You wouldn't drive without a seatbelt, though, so why would you
+# launch a kernel without some safety? Use this as a quick workaround
+# for a problem with the checker, fix the checker, then de-exclude
+# the files in question.
+exclude_files: list[str] = []
+
+# Without using a C++ AST we can't 100% detect kernel launches, so we
+# model them as having the pattern "<<<parameters>>>(arguments);"
+# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
+# the next statement.
+#
+# We model the next statement as ending at the next `}` or `;`.
+# If we see `}` then a clause ended (bad) if we see a semi-colon then
+# we expect the launch check just before it.
+#
+# Since the kernel launch can include lambda statements, it's important
+# to find the correct end-paren of the kernel launch. Doing this with
+# pure regex requires recursive regex, which aren't part of the Python
+# standard library. To avoid an additional dependency, we build a prefix
+# regex that finds the start of a kernel launch, use a paren-matching
+# algorithm to find the end of the launch, and then another regex to
+# determine if a launch check is present.
+
+# Finds potential starts of kernel launches
+# The pattern ends on the opening `(` of the launch's argument list, so for a
+# match `m` the index of that paren is `m.end() - 1`.
+kernel_launch_start = re.compile(
+    r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
+)
+
+# This pattern should start at the character after the final paren of the
+# kernel launch. It returns a match if the launch check is not the next statement
+# NOTE: despite the name, a successful match therefore means the check is MISSING.
+has_check = re.compile(
+    r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
+)
+
+def find_matching_paren(s: str, startpos: int) -> int:
+    """Find where the parenthesised group opening at ``startpos`` ends.
+
+    Args:
+        s: Text to scan.
+        startpos: Index of the opening ``(`` of the group.
+
+    Returns:
+        The index one past the ``)`` that balances the opening paren,
+        taking nested parentheses into account.
+
+    Raises:
+        IndexError: If the text ends before the group is balanced.
+    """
+    depth = 0
+    for idx, ch in enumerate(s[startpos:], start=startpos):
+        if ch == ')':
+            depth -= 1
+            # Only a closing paren can bring the nesting depth back to zero
+            if depth == 0:
+                return idx + 1
+        elif ch == '(':
+            depth += 1
+
+    raise IndexError("Closing parens not found!")
+
+
+def should_exclude_file(filename) -> bool:
+    """Return True when ``filename`` ends with any suffix listed in ``exclude_files``."""
+    return any(filename.endswith(suffix) for suffix in exclude_files)
+
+
+def check_code_for_cuda_kernel_launches(code, filename=None):
+    """Checks code for CUDA kernel launches without cuda error checks.
+
+    Args:
+        code     - The code to check
+        filename - Filename of file containing the code. Used only for display
+                   purposes, so you can put anything here.
+
+    Returns:
+        The number of unsafe kernel launches in the code
+
+    Raises:
+        IndexError - propagated from `find_matching_paren` when the argument
+                     list of a detected launch is never closed
+    """
+    if filename is None:
+        filename = "##Python Function Call##"
+
+    # Prefix every line with its line number so the context printed below is
+    # easy to locate. Numbering starts at 1 so that it agrees with what
+    # editors and compilers report for the same line.
+    code = '\n'.join(
+        f"{lineno}: {linecode}"
+        for lineno, linecode in enumerate(code.split("\n"), start=1)
+    )
+
+    num_launches_without_checks = 0
+    for m in kernel_launch_start.finditer(code):
+        # The start pattern ends on the opening paren of the argument list
+        end_paren = find_matching_paren(code, m.end() - 1)
+        # `has_check` matches when the launch check is NOT the next statement
+        if has_check.match(code, end_paren):
+            num_launches_without_checks += 1
+            context = code[m.start():end_paren + 1]
+            print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
+
+    return num_launches_without_checks
+
+
+def check_file(filename):
+    """Count unchecked CUDA kernel launches in a single file.
+
+    Files that do not end in `.cu` / `.cuh`, or that are matched by
+    `should_exclude_file`, are not opened and count as zero.
+
+    Args:
+        filename - File to check
+
+    Returns:
+        The number of unsafe kernel launches in the file
+    """
+    is_cuda_source = filename.endswith((".cu", ".cuh"))
+    if not is_cuda_source or should_exclude_file(filename):
+        return 0
+    with open(filename) as source:
+        text = source.read()
+        return check_code_for_cuda_kernel_launches(text, filename)
+
+
+def check_cuda_kernel_launches():
+    """Checks all pytorch code for CUDA kernel launches without cuda error checks
+
+    The walk is rooted two directory levels above the directory that contains
+    this file. A summary of offending files is printed to stderr.
+
+    Returns:
+        The number of unsafe kernel launches in the codebase
+    """
+    torch_dir = os.path.dirname(os.path.realpath(__file__))  # directory of this file
+    torch_dir = os.path.dirname(torch_dir)  # one level up
+    torch_dir = os.path.dirname(torch_dir)  # two levels up: root of the walk
+
+    # `$BASE/build` and `$BASE/torch/include` are generated
+    # so we don't want to flag their contents.
+    # Join path components individually: `os.walk` yields roots that use the
+    # platform separator, so a hard-coded "torch/include" would never compare
+    # equal on Windows.
+    generated_dirs = (
+        os.path.join(torch_dir, "build"),
+        os.path.join(torch_dir, "torch", "include"),
+    )
+
+    kernels_without_checks = 0
+    files_without_checks = []
+    for root, dirnames, filenames in os.walk(torch_dir):
+        if root in generated_dirs:
+            # Curtail search by modifying dirnames in place
+            # Yes, this is the way to do this, see `help(os.walk)`
+            dirnames[:] = []
+            continue
+
+        for x in filenames:
+            filename = os.path.join(root, x)
+            file_result = check_file(filename)
+            if file_result > 0:
+                kernels_without_checks += file_result
+                files_without_checks.append(filename)
+
+    if kernels_without_checks > 0:
+        count_str = f"Found {kernels_without_checks} instances in " \
+                    f"{len(files_without_checks)} files where kernel " \
+                    "launches didn't have checks."
+        # The summary is printed both before and after the file list
+        print(count_str, file=sys.stderr)
+        print("Files without checks:", file=sys.stderr)
+        for x in files_without_checks:
+            print(f"\t{x}", file=sys.stderr)
+        print(count_str, file=sys.stderr)
+
+    return kernels_without_checks
+
+
+# Script entry point: the process exit status is 0 only when no unchecked
+# kernel launches were found, and 1 otherwise.
+if __name__ == "__main__":
+    unsafe_launches = check_cuda_kernel_launches()
+    sys.exit(0 if unsafe_launches == 0 else 1)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5d0cf2992110f4cc0120f58fb7ba39f8f87947
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py
@@ -0,0 +1,387 @@
+# mypy: ignore-errors
+
+r"""This file is allowed to initialize CUDA context when imported."""
+
+import functools
+import torch
+import torch.cuda
+from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS
+import inspect
+import contextlib
+import os
+import unittest
+
+
+# Snapshot of the CUDA initialization state taken while this module is being
+# imported; it is consulted again by the assertion at the end of the module.
+CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
+
+
+TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
+CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
+# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
+if TEST_WITH_ROCM:
+    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
+else:
+    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
+
+# cuDNN version number, or 0 when TEST_CUDNN is falsy
+TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
+# First two dot-separated components of torch.version.hip as ints, or (0, 0) when hip is unset
+ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0))
+
+# Compute-capability thresholds for the current CUDA device. Each predicate is
+# wrapped in LazyVal (presumably so the device query is deferred until first
+# use -- confirm against LazyVal in common_utils).
+SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
+SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
+SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
+SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
+SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
+SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
+SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
+SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
+SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0))
+
+# IS_THOR: major capability 10 with a non-zero minor version
+IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
+                  and torch.cuda.get_device_capability()[1] > 0)
+# IS_JETSON: capability (7, 2) or (8, 7), or anything matched by IS_THOR
+IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
+# Exact-match capability checks (as opposed to the "OrLater" thresholds above)
+IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
+IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0))
+IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0))
+
+def evaluate_gfx_arch_within(arch_list):
+    """Return True if any name in ``arch_list`` occurs in the active gfx arch string.
+
+    The arch string is the ``gcnArchName`` of the 'cuda' device unless the
+    PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE environment variable is
+    set, in which case that value is used. Returns False when CUDA is
+    unavailable.
+    """
+    if torch.cuda.is_available():
+        reported_arch = torch.cuda.get_device_properties('cuda').gcnArchName
+        active_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', reported_arch)
+        # gcnArchName can look like "gfx90a:sramecc+:xnack-", so test each
+        # candidate as a substring of the reported name, not the other way round
+        for candidate in arch_list:
+            if candidate in active_arch:
+                return True
+    return False
+
+def CDNA3OrLater():
+    """Return whether the active gfx architecture matches one of the names listed here."""
+    matching_archs = ["gfx940", "gfx941", "gfx942", "gfx950"]
+    return evaluate_gfx_arch_within(matching_archs)
+
+def CDNA2OrLater():
+    """Return whether the active gfx architecture matches one of the names listed here."""
+    # NOTE(review): unlike CDNA3OrLater, this list omits gfx940, gfx941 and
+    # gfx950 -- confirm whether that is intentional for an "OrLater" check.
+    matching_archs = ["gfx90a", "gfx942"]
+    return evaluate_gfx_arch_within(matching_archs)
+
+def evaluate_platform_supports_flash_attention():
+    """Decide whether flash attention tests can run on this platform.
+
+    On ROCm the answer depends on the gfx architecture (with a wider list when
+    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL is set to something other than "0").
+    On CUDA it is `not IS_WINDOWS and SM80OrLater`. Otherwise False.
+    """
+    if TEST_WITH_ROCM:
+        experimental = os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0"
+        supported = ["gfx90a", "gfx942", "gfx1201", "gfx950"]
+        if experimental:
+            supported.extend(["gfx1100", "gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"])
+        return evaluate_gfx_arch_within(supported)
+    if not TEST_CUDA:
+        return False
+    return not IS_WINDOWS and SM80OrLater
+
+def evaluate_platform_supports_efficient_attention():
+    """Decide whether memory-efficient attention tests can run on this platform.
+
+    On ROCm the answer depends on the gfx architecture (with a wider list when
+    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL is set to something other than "0").
+    On CUDA it is always True. Otherwise False.
+    """
+    if TEST_WITH_ROCM:
+        experimental = os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0"
+        supported = ["gfx90a", "gfx942", "gfx1201", "gfx950"]
+        if experimental:
+            supported.extend(["gfx1100", "gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"])
+        return evaluate_gfx_arch_within(supported)
+    if not TEST_CUDA:
+        return False
+    return True
+
+def evaluate_platform_supports_cudnn_attention():
+    """cuDNN attention needs a non-ROCm build, SM80OrLater, and TEST_CUDNN_VERSION >= 90000."""
+    if TEST_WITH_ROCM:
+        return False
+    return SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
+
+def evaluate_platform_supports_green_context():
+    """Decide whether green-context tests can run on this platform.
+
+    Requires a non-Windows host, a torch CUDA version of at least 12.8, and an
+    NVIDIA driver whose major version is at least 570. Returns False when the
+    driver version cannot be determined.
+    """
+    if IS_WINDOWS or _get_torch_cuda_version() < (12, 8):
+        return False
+    collect_env = torch.utils.collect_env
+    driver_version = collect_env.get_nvidia_driver_version(collect_env.run)
+    if driver_version is None:
+        return False
+    driver_major = int(driver_version.split('.')[0])
+    return driver_major >= 570
+
+# NOTE: these names are annotated as `bool`, but all of them except
+# PLATFORM_SUPPORTS_FUSED_SDPA are bound to a LazyVal wrapper around the
+# corresponding evaluate_* function rather than to a plain bool.
+PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
+PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
+PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
+# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
+PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or
+                                                  PLATFORM_SUPPORTS_CUDNN_ATTENTION or
+                                                  PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
+
+# Computed eagerly at import time from two already-evaluated flags
+PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
+
+PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
+
+PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context())
+
+def evaluate_platform_supports_fp8():
+    """Decide whether FP8 tests can run on this platform.
+
+    On HIP builds the device's gcnArchName must contain one of the listed
+    architecture prefixes (the list grows with ROCM_VERSION). On CUDA builds
+    the device must satisfy SM90OrLater or have capability (8, 9).
+    """
+    if not torch.cuda.is_available():
+        return False
+    if not torch.version.hip:
+        return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
+
+    arch_prefixes = ['gfx94']
+    if ROCM_VERSION >= (6, 3):
+        arch_prefixes.append('gfx120')
+    if ROCM_VERSION >= (6, 5):
+        arch_prefixes.append('gfx95')
+    device_arch = torch.cuda.get_device_properties(0).gcnArchName
+    return any(prefix in device_arch for prefix in arch_prefixes)
+
+def evaluate_platform_supports_fp8_grouped_gemm():
+    """Decide whether FP8 grouped GEMM tests can run on this platform.
+
+    HIP builds need "USE_FBGEMM_GENAI" in the torch build config and a device
+    whose gcnArchName contains 'gfx942'. CUDA builds need SM90OrLater but not
+    SM100OrLater.
+    """
+    if not torch.cuda.is_available():
+        return False
+    if not torch.version.hip:
+        return SM90OrLater and not SM100OrLater
+    if "USE_FBGEMM_GENAI" not in torch.__config__.show():
+        return False
+    return 'gfx942' in torch.cuda.get_device_properties(0).gcnArchName
+
+def evaluate_platform_supports_mx_gemm():
+    """Decide whether MX GEMM tests can run on this platform.
+
+    HIP builds need ROCM_VERSION >= (7, 0) and a device whose gcnArchName
+    contains 'gfx950'. CUDA builds return SM100OrLater.
+    """
+    if not torch.cuda.is_available():
+        return False
+    if not torch.version.hip:
+        return SM100OrLater
+    if ROCM_VERSION >= (7, 0):
+        return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
+    return False
+
+def evaluate_platform_supports_mxfp8_grouped_gemm():
+    """MXFP8 grouped GEMM needs a non-HIP CUDA build with USE_FBGEMM_GENAI and IS_SM100."""
+    if not torch.cuda.is_available() or torch.version.hip:
+        return False
+    config_text = torch.__config__.show()
+    return ("USE_FBGEMM_GENAI" in config_text) and IS_SM100
+
+# As above: annotated `bool`, but each name is bound to a LazyVal wrapper.
+PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
+PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
+PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
+PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm())
+
+# Probe numba's CUDA support only when numba itself was detected. Any failure
+# while importing or querying numba.cuda turns both flags off; note that this
+# rebinds the module-local TEST_NUMBA name that was imported from common_utils.
+if TEST_NUMBA:
+    try:
+        import numba.cuda
+        TEST_NUMBA_CUDA = numba.cuda.is_available()
+    except Exception:
+        TEST_NUMBA_CUDA = False
+        TEST_NUMBA = False
+else:
+    TEST_NUMBA_CUDA = False
+
+# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
+# RNG have been initialized.
+__cuda_ctx_rng_initialized = False
+
+
+# after this call, CUDA context and RNG must have been initialized on each GPU
+def initialize_cuda_context_rng():
+    """Create a one-element random tensor on every CUDA device, once per process.
+
+    Subsequent calls are no-ops, tracked through the module-level
+    `__cuda_ctx_rng_initialized` flag. Asserts that TEST_CUDA is truthy.
+    """
+    global __cuda_ctx_rng_initialized
+    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
+    if __cuda_ctx_rng_initialized:
+        return
+    # initialize cuda context and rng for memory tests
+    for device_index in range(torch.cuda.device_count()):
+        torch.randn(1, device=f"cuda:{device_index}")
+    __cuda_ctx_rng_initialized = True
+
+
+@contextlib.contextmanager
+def tf32_off():
+    """Context manager that disables TF32 for CUDA matmul and cuDNN.
+
+    The previous matmul `allow_tf32` value is restored on exit, including
+    when the body raises.
+    """
+    matmul_settings = torch.backends.cuda.matmul
+    saved_allow_tf32 = matmul_settings.allow_tf32
+    try:
+        matmul_settings.allow_tf32 = False
+        cudnn_override = torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
+        )
+        with cudnn_override:
+            yield
+    finally:
+        matmul_settings.allow_tf32 = saved_allow_tf32
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):
+    """Context manager that enables TF32 and loosens ``self.precision``.
+
+    While active, matmul and cuDNN TF32 are allowed and ``self.precision`` is
+    set to ``tf32_precision``. Both the matmul flag and ``self.precision`` are
+    restored on exit, including when the body raises.
+    """
+    matmul_settings = torch.backends.cuda.matmul
+    saved_allow_tf32 = matmul_settings.allow_tf32
+    saved_precision = self.precision
+    try:
+        matmul_settings.allow_tf32 = True
+        self.precision = tf32_precision
+        cudnn_override = torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        )
+        with cudnn_override:
+            yield
+    finally:
+        matmul_settings.allow_tf32 = saved_allow_tf32
+        self.precision = saved_precision
+
+
+@contextlib.contextmanager
+def tf32_enabled():
+    """
+    Temporarily allow TF32 for CUDA matmul and cuDNN.
+    The matmul TF32 setting that was active on entry is put back on exit,
+    including when the body raises.
+    """
+    matmul_settings = torch.backends.cuda.matmul
+    saved_allow_tf32 = matmul_settings.allow_tf32
+    try:
+        matmul_settings.allow_tf32 = True
+        cudnn_override = torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        )
+        with cudnn_override:
+            yield
+    finally:
+        matmul_settings.allow_tf32 = saved_allow_tf32
+
+
+# This is a wrapper that wraps a test to run this test twice, one with
+# allow_tf32=True, another with allow_tf32=False. When running with
+# allow_tf32=True, it will use reduced precision as specified by the
+# argument. For example:
+#    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+#    @tf32_on_and_off(0.005)
+#    def test_matmul(self, device, dtype):
+#        a = ...; b = ...;
+#        c = torch.matmul(a, b)
+#        self.assertEqual(c, expected)
+# In the above example, when testing torch.float32 and torch.complex64 on CUDA
+# on a CUDA >= 11 build on an >=Ampere architecture, the matmul will be running at
+# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced
+# precision to check values.
+#
+# This decorator can be used for function with or without device/dtype, such as
+# @tf32_on_and_off(0.005)
+# def test_my_op(self)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, device)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, device, dtype)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, dtype)
+# if neither device nor dtype is specified, it will check if the system has ampere device
+# if device is specified, it will check if device is cuda
+# if dtype is specified, it will check if dtype is float32 or complex64
+# tf32 and fp32 are different only when all the three checks pass
+def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True):
+    """Decorator factory: run the wrapped test with TF32 off and then on.
+
+    Args:
+        tf32_precision: value assigned to ``self.precision`` for the TF32-on run.
+        only_if: extra condition; when falsy the test runs once, unmodified.
+
+    Returns:
+        A decorator. The decorated function discards the wrapped function's
+        return value.
+    """
+    # `self` is accepted for symmetry with with_tf32_enabled but is not used here
+    def with_tf32_disabled(self, function_call):
+        with tf32_off():
+            function_call()
+
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            # Fold positional arguments into kwargs by parameter name so that
+            # 'self', 'device' and 'dtype' can be looked up uniformly below
+            kwargs.update(zip(arg_names, args, strict=False))
+            cond = torch.cuda.is_tf32_supported() and only_if
+            if 'device' in kwargs:
+                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
+            if 'dtype' in kwargs:
+                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
+            if cond:
+                # Run twice: first with TF32 disabled, then with it enabled
+                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
+                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+    return wrapper
+
+# This is a wrapper that wraps a test to run it with TF32 turned off.
+# This wrapper is designed to be used when a test uses matmul or convolutions
+# but the purpose of that test is not testing matmul or convolutions.
+# Disabling TF32 will enforce torch.float tensors to be always computed
+# at full precision.
+def with_tf32_off(f):
+    """Decorator that runs ``f`` inside the ``tf32_off`` context and returns its result."""
+    @functools.wraps(f)
+    def wrapped(*args, **kwargs):
+        with tf32_off():
+            outcome = f(*args, **kwargs)
+        return outcome
+
+    return wrapped
+
+def _get_magma_version():
+    """Parse the Magma version out of the torch build configuration text.
+
+    Returns:
+        A tuple of ints taken from the text that follows 'Magma ' up to the
+        end of that line, or (0, 0) when 'Magma' does not appear at all.
+    """
+    config_text = torch.__config__.show()
+    if 'Magma' not in config_text:
+        return (0, 0)
+    marker = 'Magma '
+    version_start = config_text.find(marker) + len(marker)
+    first_line = config_text[version_start:].split('\n')[0]
+    return tuple(int(piece) for piece in first_line.split("."))
+
+def _get_torch_cuda_version():
+    """Return torch.version.cuda as a tuple of ints, or (0, 0) when it is None."""
+    raw_version = torch.version.cuda
+    if raw_version is None:
+        return (0, 0)
+    return tuple(int(piece) for piece in str(raw_version).split("."))
+
+def _get_torch_rocm_version():
+    """Return torch.version.hip as a tuple of ints, or (0, 0) outside ROCm testing.
+
+    Anything from the first '-' onwards (the git sha) is dropped before parsing.
+    """
+    if not TEST_WITH_ROCM or torch.version.hip is None:
+        return (0, 0)
+    release, _, _ = str(torch.version.hip).partition("-")  # ignore git sha
+    return tuple(int(piece) for piece in release.split("."))
+
+def _check_cusparse_generic_available():
+    # True exactly when TEST_WITH_ROCM is falsy; no CUDA or cuSPARSE query is made here
+    return not TEST_WITH_ROCM
+
+def _check_hipsparse_generic_available():
+    """Return True when testing with ROCm and torch.version.hip is at least 5.1."""
+    if not TEST_WITH_ROCM or not torch.version.hip:
+        return False
+
+    release = str(torch.version.hip).split("-", maxsplit=1)[0]    # ignore git sha
+    release_tuple = tuple(int(piece) for piece in release.split("."))
+    return release_tuple is not None and not release_tuple < (5, 1)
+
+
+# Both flags are evaluated eagerly at import time
+TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
+TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()
+
+# Shared by test_torch.py and test_multigpu.py
+def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
+    """Build two identical two-layer models and one optimizer for each.
+
+    The "scaling" model's parameters are overwritten with copies of the
+    "control" model's parameters so the two start out equal. Optimizers are
+    constructed with ``lr=1.0`` updated by ``optimizer_kwargs``.
+
+    Returns:
+        (mod_control, mod_scaling, opt_control, opt_scaling)
+    """
+    def _two_layer_net():
+        return torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
+
+    mod_control = _two_layer_net()
+    mod_scaling = _two_layer_net()
+    with torch.no_grad():
+        param_pairs = zip(mod_control.parameters(), mod_scaling.parameters(), strict=True)
+        for source, target in param_pairs:
+            target.copy_(source)
+
+    ctor_kwargs = {"lr": 1.0}
+    if optimizer_kwargs is not None:
+        ctor_kwargs.update(optimizer_kwargs)
+    opt_control, opt_scaling = (
+        optimizer_ctor(model.parameters(), **ctor_kwargs)
+        for model in (mod_control, mod_scaling)
+    )
+
+    return mod_control, mod_scaling, opt_control, opt_scaling
+
+# Shared by test_torch.py, test_cuda.py and test_multigpu.py
+def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
+    """Build a complete scaling test case.
+
+    Returns:
+        The four values from `_create_scaling_models_optimizers` followed by
+        ``data`` (four pairs of random 8x8 tensors), an MSE loss module, and
+        ``skip_iter`` (2).
+    """
+    def _random_batch():
+        return torch.randn((8, 8), dtype=dtype, device=device)
+
+    data = [(_random_batch(), _random_batch()) for _ in range(4)]
+
+    loss_fn = torch.nn.MSELoss().to(device)
+
+    skip_iter = 2
+
+    models_and_optimizers = _create_scaling_models_optimizers(
+        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
+    )
+    return models_and_optimizers + (data, loss_fn, skip_iter)
+
+
+def xfailIfSM89(func):
+    """Mark ``func`` as an expected failure when IS_SM89 is truthy; otherwise return it unchanged."""
+    if IS_SM89:
+        return unittest.expectedFailure(func)
+    return func
+
+def xfailIfSM89PreCUDA13(func):
+    """Mark ``func`` as an expected failure on SM89 when the torch CUDA version is below 13.0."""
+    expect_failure = IS_SM89 and _get_torch_cuda_version() < (13, 0)
+    return unittest.expectedFailure(func) if expect_failure else func
+
+def xfailIfSM100OrLater(func):
+    """Mark ``func`` as an expected failure when SM100OrLater is truthy; otherwise return it unchanged."""
+    if SM100OrLater:
+        return unittest.expectedFailure(func)
+    return func
+
+def xfailIfSM120OrLater(func):
+    """Mark ``func`` as an expected failure when SM120OrLater is truthy; otherwise return it unchanged."""
+    if SM120OrLater:
+        return unittest.expectedFailure(func)
+    return func
+
+def xfailIfDistributedNotSupported(func):
+    """Mark ``func`` as an expected failure when IS_MACOS or IS_JETSON is truthy."""
+    if IS_MACOS or IS_JETSON:
+        return unittest.expectedFailure(func)
+    return func
+
+# Importing this module should NOT eagerly initialize CUDA
+if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
+    assert not torch.cuda.is_initialized()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd14b85a21915ddf8ab415f3bf5dc6e79db14dfc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py
@@ -0,0 +1,112 @@
+# mypy: ignore-errors
+
+# Owner(s): ["oncall: distributed"]
+
+
+import torch
+import torch.nn as nn
+
+
+class UnitModule(nn.Module):
+    """Linear -> (ReLU, Linear, ReLU) -> Linear, all layers 100 -> 100."""
+
+    def __init__(self, device: torch.device):
+        super().__init__()
+        # Layers are created in attribute order: l1, the Linear inside seq, l2
+        self.l1 = nn.Linear(100, 100, device=device)
+        inner_layers = [nn.ReLU(), nn.Linear(100, 100, device=device), nn.ReLU()]
+        self.seq = nn.Sequential(*inner_layers)
+        self.l2 = nn.Linear(100, 100, device=device)
+
+    def forward(self, x):
+        hidden = self.l1(x)
+        hidden = self.seq(hidden)
+        return self.l2(hidden)
+
+
+class CompositeModel(nn.Module):
+    """Linear -> UnitModule -> UnitModule -> Linear."""
+
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l1 = nn.Linear(100, 100, device=device)
+        self.u1 = UnitModule(device)
+        self.u2 = UnitModule(device)
+        self.l2 = nn.Linear(100, 100, device=device)
+
+    def forward(self, x):
+        # Apply the four stages in registration order
+        out = x
+        for stage in (self.l1, self.u1, self.u2, self.l2):
+            out = stage(out)
+        return out
+
+
+class UnitParamModule(nn.Module):
+    """Linear -> (ReLU, Linear, ReLU), then a matrix product with a 100x100 parameter."""
+
+    def __init__(self, device: torch.device):
+        super().__init__()
+        # Creation order: l, the Linear inside seq, then the raw parameter p
+        self.l = nn.Linear(100, 100, device=device)
+        inner_layers = [nn.ReLU(), nn.Linear(100, 100, device=device), nn.ReLU()]
+        self.seq = nn.Sequential(*inner_layers)
+        self.p = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x):
+        activations = self.seq(self.l(x))
+        return torch.mm(activations, self.p)
+
+
+class CompositeParamModel(nn.Module):
+    """Linear -> UnitModule -> UnitModule, then a matrix product with a 100x100 parameter.
+
+    A persistent 100x100 buffer named "buffer" is also registered; ``forward``
+    does not read it.
+    """
+
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l = nn.Linear(100, 100, device=device)
+        self.u1 = UnitModule(device)
+        self.u2 = UnitModule(device)
+        self.p = nn.Parameter(torch.randn((100, 100), device=device))
+        buffer_init = torch.randn((100, 100), device=device)
+        self.register_buffer("buffer", buffer_init, persistent=True)
+
+    def forward(self, x):
+        features = self.l(x)
+        features = self.u1(features)
+        features = self.u2(features)
+        return torch.mm(features, self.p)
+
+
+class FakeSequential(nn.Module):
+    # Stand-in for `nn.Sequential`, defined so that a desired nested wrapping
+    # can be achieved when the module wrap policy targets `nn.Sequential`.
+    # The wrapped modules are kept in a plain Python list.
+    def __init__(self, *modules: tuple[nn.Module, ...]) -> None:
+        super().__init__()
+        self._module_sequence = [*modules]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Feed the input through each stored module in order
+        out = x
+        for layer in self._module_sequence:
+            out = layer(out)
+        return out
+
+
+class NestedSequentialModel(nn.Module):
+    """Model mixing `nn.Sequential` and `FakeSequential` containers at several depths.
+
+    ``forward`` applies ``seq1``, then ``lin``, then ``seq2``. Going by the
+    layer sizes below, the input's last dimension is 1 and the output's is 4.
+    """
+
+    def __init__(self, device: torch.device) -> None:
+        super().__init__()
+        # This nested structure exercises traversal order to catch differences
+        # between valid traversals (e.g. BFS and DFS variations).
+        self.seq1 = nn.Sequential(
+            nn.Linear(1, 1, device=device),
+            FakeSequential(
+                nn.Linear(1, 1, device=device),
+                nn.ReLU(),
+                FakeSequential(
+                    nn.Linear(1, 1, device=device),
+                ),
+                nn.ReLU(),
+            ),
+            nn.Linear(1, 2, device=device),
+        )
+        self.lin = nn.Linear(2, 2, device=device)
+        self.seq2 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(2, 3, device=device),
+            FakeSequential(
+                nn.Linear(3, 2, bias=False, device=device),
+                nn.Linear(2, 4, bias=False, device=device),
+            ),
+        )
+
+    # Defined at class level so that it is an actual method: it was previously
+    # nested inside __init__ as an unused local function, leaving the module
+    # without a forward implementation.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.seq2(self.lin(self.seq1(x)))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df79fa00f81b92492fcd6f23a99f595695b8421
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py
@@ -0,0 +1,1958 @@
+# mypy: ignore-errors
+
+import faulthandler
+import functools
+import itertools
+import logging
+import multiprocessing
+import operator
+import os
+import queue
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+from collections.abc import Callable
+from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from enum import Enum
+from functools import partial, reduce, wraps
+from io import StringIO
+from typing import Any, NamedTuple, Optional, Union
+from unittest.mock import patch
+
+import torch
+import torch._dynamo.test_case
+import torch.cuda.nccl
+import torch.distributed as c10d
+import torch.nn as nn
+from torch._C._autograd import DeviceType
+from torch._C._distributed_c10d import _SymmetricMemory
+from torch._logging._internal import trace_log
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    find_free_port,
+    IS_SANDCASTLE,
+    LazyVal,
+    retry_on_connect_failures,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_WITH_ROCM,
+    TEST_WITH_TSAN,
+    TEST_XPU,
+    TestCase,
+)
+from torch.testing._internal.distributed.multi_threaded_pg import (
+    _install_threaded_pg,
+    _uninstall_threaded_pg,
+    ProcessLocalGroup,
+)
+
+
+# Module-level logger; its level is forced to INFO when this module is imported
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+# Backend names treated as accelerator backends
+ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
+# Device type strings for which `requires_ddp_rank` returns True
+DDP_RANK_DEVICES = ["cuda", "xpu"]
+HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
+
+
+class TestSkip(NamedTuple):
+    """Pairs a process exit code with the human-readable reason for a skipped test."""
+
+    # Exit status passed to sys.exit by the skip decorators in this file
+    exit_code: int
+    # Explanation of why the test was skipped
+    message: str
+
+
+# Registry of known skip reasons, keyed by a short name. The skip decorators
+# in this file look up an entry and exit the process with its exit_code.
+# The "multi-gpu-N" keys exist for N = 1..8 only.
+TEST_SKIPS = {
+    "backend_unavailable": TestSkip(
+        72, "Skipped because distributed backend is not available."
+    ),
+    "small_worldsize": TestSkip(73, "Skipped due to small world size."),
+    "odd_worldsize": TestSkip(87, "Skipped due to odd world size."),
+    "no_cuda": TestSkip(74, "CUDA is not available."),
+    "multi-gpu-1": TestSkip(75, "Need at least 1 CUDA device"),
+    "multi-gpu-2": TestSkip(77, "Need at least 2 CUDA devices"),
+    "multi-gpu-3": TestSkip(80, "Need at least 3 CUDA devices"),
+    "multi-gpu-4": TestSkip(81, "Need at least 4 CUDA devices"),
+    "multi-gpu-5": TestSkip(82, "Need at least 5 CUDA devices"),
+    "multi-gpu-6": TestSkip(83, "Need at least 6 CUDA devices"),
+    "multi-gpu-7": TestSkip(84, "Need at least 7 CUDA devices"),
+    "multi-gpu-8": TestSkip(85, "Need at least 8 CUDA devices"),
+    "nccl": TestSkip(76, "c10d not compiled with NCCL support"),
+    "skipIfRocm": TestSkip(78, "Test skipped for ROCm"),
+    "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"),
+    "generic": TestSkip(
+        86, "Test skipped at subprocess level, look at subprocess log for skip reason"
+    ),
+    "importerror": TestSkip(88, "Test skipped due to missing import"),
+    "no_accelerator": TestSkip(89, "accelerator is not available."),
+}
+
+
+@dataclass
+class DistTestCases:
+    """Lookup tables describing which distributed backends support which features.
+
+    Neither table carries a type annotation, so both are plain class
+    attributes shared by the class rather than per-instance dataclass fields.
+    """
+
+    # Backends that do not support a specific collective
+    skip_collective = {}
+    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"}
+    skip_collective["reduce"] = set()
+    skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"}
+    skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"}
+
+    # Sets showing that something is implemented
+    backend_feature = {}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
+    backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
+    backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
+    backend_feature["plugin"] = set()
+    # The "hpu" / "xpu" entries are only present when the matching TEST_* flag is set
+    if TEST_HPU:
+        backend_feature["hpu"] = {"hccl"}
+    if TEST_XPU:
+        backend_feature["xpu"] = {"xccl"}
+
+
+def requires_ddp_rank(device):
+    """Return whether ``device`` is a member of ``DDP_RANK_DEVICES``."""
+    return device in DDP_RANK_DEVICES
+
+
+def skip_if_no_gpu(func):
+    """Skips if the world size exceeds the number of GPUs, ensuring that if the
+    test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not (TEST_CUDA or TEST_HPU or TEST_XPU):
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+        world_size = int(os.environ["WORLD_SIZE"])
+        if TEST_CUDA and torch.cuda.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_HPU and torch.hpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_XPU and torch.xpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+# TODO (kwen2501): what is the purpose of this decorator?  Tests with this
+# decorator were always skipped. So they may be outdated already.
+# Oct 2024: bumping the small-world criteria to < 8, as we are increasing the
+# number of GPUs in CI from 2 to 4, and we need to continue skipping those tests
+# to keep CI green. But this is just a temporary solution. We should clean up
+# those tests somehow.
+def skip_if_small_worldsize(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) < 8:
+            sys.exit(TEST_SKIPS["small_worldsize"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def skip_if_odd_worldsize(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) % 2 == 1:
+            sys.exit(TEST_SKIPS["odd_worldsize"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_n_gpus_for_nccl_backend(n, backend):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend == "nccl" and torch.cuda.device_count() < n:
+                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def import_transformers_or_skip():
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
+
+                return func(*args, **kwargs)
+            except ImportError:
+                sys.exit(TEST_SKIPS["importerror"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def at_least_x_gpu(x):
+    if TEST_CUDA and torch.cuda.device_count() >= x:
+        return True
+    if TEST_HPU and torch.hpu.device_count() >= x:
+        return True
+    if TEST_XPU and torch.xpu.device_count() >= x:
+        return True
+    return False
+
+
+def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
+    _handle_test_skip = getattr(args[0], "_handle_test_skip", None)
+    if len(args) == 0 or _handle_test_skip is None:
+        return False
+    _handle_test_skip(msg)
+    return True
+
+
+def skip_if_lt_x_gpu(x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_HPU and torch.hpu.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_XPU and torch.xpu.device_count() >= x:
+                return func(*args, **kwargs)
+            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
+            if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
+                sys.exit(test_skip.exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def requires_world_size(n: int):
+    """
+    Decorator to request a specific world size for a test. The test harness can
+    read this attribute to set the number of ranks to spawn. If there are fewer
+    than `n` CUDA devices available, the test should be skipped by the harness.
+
+    Usage:
+        @require_world_size(3)
+        def test_something(self):
+            ...
+    """
+
+    def decorator(func):
+        func._required_world_size = n
+        available = torch.cuda.device_count()
+        return unittest.skipUnless(
+            available >= n, f"requires {n} GPUs, found {available}"
+        )(func)
+
+    return decorator
+
+
+def get_required_world_size(obj: Any, default: int) -> int:
+    """
+    Returns the requested world size for the currently running unittest method on `obj`
+    if annotated via `@require_world_size(n)`, else returns `default`.
+    """
+    try:
+        # Try MultiProcessTestCase helper first, then unittest fallback
+        test_name = (
+            obj._current_test_name()  # type: ignore[attr-defined]
+            if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
+            else obj._testMethodName
+        )
+        fn = getattr(obj, test_name)
+        value = fn._required_world_size
+        return int(value)
+    except Exception:
+        return default
+
+
+# This decorator helps avoiding initializing cuda while testing other backends
+def nccl_skip_if_lt_x_gpu(backend, x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend != "nccl":
+                return func(*args, **kwargs)
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
+            if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
+                sys.exit(test_skip.exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def verify_ddp_error_logged(model_DDP, err_substr):
+    # Verify error was logged in ddp_logging_data.
+    ddp_logging_data = model_DDP._get_ddp_logging_data()
+    assert "iteration" in ddp_logging_data
+    assert "has_error" in ddp_logging_data
+    assert "error" in ddp_logging_data
+    logging_err = ddp_logging_data["error"]
+    # Remove C++ stacktrace if needed.
+    actual = (
+        err_substr
+        if err_substr.find("\nException raised from ") == -1
+        else err_substr.split("\nException raised from ")[0]
+    )
+    assert actual in logging_err, (
+        f"Did not find expected {actual} in ddp logging data error: {logging_err}"
+    )
+
+
+def with_nccl_blocking_wait(func):
+    """
+    Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of
+    this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for
+    the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and
+    TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING
+        try:
+            cached_nccl_async_error_handling: Union[str, None] = os.environ[
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+            ]
+            del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
+        except KeyError:
+            # TORCH_NCCL_ASYNC_ERROR_HANDLING was unset
+            cached_nccl_async_error_handling = None
+
+        # Save val of TORCH_NCCL_BLOCKING_WAIT and set it.
+        try:
+            cached_nccl_blocking_wait: Union[str, None] = os.environ[
+                "TORCH_NCCL_BLOCKING_WAIT"
+            ]
+        except KeyError:
+            cached_nccl_blocking_wait = None
+        finally:
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
+
+        try:
+            ret = func(*args, **kwargs)
+            return ret
+        finally:
+            # restore old values.
+            if cached_nccl_async_error_handling is not None:
+                os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                    cached_nccl_async_error_handling
+                )
+
+            if cached_nccl_blocking_wait is not None:
+                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
+
+    return wrapper
+
+
+def with_dist_debug_levels(levels):
+    """
+    Runs a test for each distributed debug level specified in levels.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            old_level = os.environ.get("TORCH_DISTRIBUTED_DEBUG", None)
+            for level in levels:
+                os.environ["TORCH_DISTRIBUTED_DEBUG"] = level
+                c10d.set_debug_level_from_env()
+                ret = func(*args, **kwargs)
+                c10d.barrier()
+                if old_level is not None:
+                    os.environ["TORCH_DISTRIBUTED_DEBUG"] = old_level
+            # Only returns test return for last test, but since these are
+            # unittests the return value is not really used and earlier tests
+            # would've raised had they failed.
+            return ret
+
+        return wrapper
+
+    return decorator
+
+
+def requires_gloo():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_gloo_available(),
+        "c10d was not compiled with the Gloo backend",
+    )
+
+
+def requires_nccl_version(version, msg):
+    if not TEST_CUDA:
+        return lambda f: f
+    if not c10d.is_nccl_available():
+        return skip_but_pass_in_sandcastle(
+            "c10d was not compiled with the NCCL backend",
+        )
+    else:
+        return skip_but_pass_in_sandcastle_if(
+            torch.cuda.nccl.version() < version,
+            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
+        )
+
+
+def requires_nccl_shrink():
+    """
+    Require NCCL shrink support (NCCL available and version >= 2.27).
+
+    Thin wrapper that delegates to ``requires_nccl_version`` with the minimum
+    version ``(2, 27)``.
+    """
+    return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
+
+
+def requires_nccl():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_nccl_available(),
+        "c10d was not compiled with the NCCL backend",
+    )
+
+
+def requires_ucc():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_ucc_available(),
+        "c10d was not compiled with the UCC backend",
+    )
+
+
+def requires_mpi():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_mpi_available(),
+        "c10d was not compiled with the MPI backend",
+    )
+
+
+def requires_accelerator_dist_backend(backends=None):
+    """
+    Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
+
+    Args:
+        backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
+                                       If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
+
+    Returns:
+        callable: A decorator that skips the test if no specified accelerator backend is available.
+    """
+    if backends is None:
+        backends = ACCELERATOR_DIST_BACKENDS
+
+    backend_available = any(
+        {
+            "nccl": c10d.is_nccl_available,
+            "xccl": c10d.is_xccl_available,
+            "hccl": lambda: TEST_HPU,
+        }.get(backend, lambda: False)()
+        for backend in backends
+    )
+
+    return skip_but_pass_in_sandcastle_if(
+        not backend_available,
+        f"No accelerator communication backend available among {backends}",
+    )
+
+
+def requires_multicast_support():
+    has_multicast_support = (
+        torch.cuda.is_available()
+        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
+    )
+    return skip_but_pass_in_sandcastle_if(
+        not has_multicast_support,
+        "multicast support is not available",
+    )
+
+
+def evaluate_platform_supports_symm_mem():
+    if TEST_CUDA:
+        if TEST_WITH_ROCM:
+            arch_list = ["gfx942", "gfx950"]
+            for arch in arch_list:
+                if arch in torch.cuda.get_device_properties(0).gcnArchName:
+                    return True
+            return False
+        else:
+            return True
+    else:
+        return False
+
+
+# Result of ``evaluate_platform_supports_symm_mem`` wrapped in ``LazyVal``.
+# NOTE(review): presumably this defers the device query until the value is
+# first needed rather than running it at import time; confirm against LazyVal.
+PLATFORM_SUPPORTS_SYMM_MEM: bool = LazyVal(
+    lambda: evaluate_platform_supports_symm_mem()
+)
+
+
+def skip_if_rocm_multiprocess(func):
+    """Skips a test for ROCm multiprocess UTs"""
+    return unittest.skipIf(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func)
+
+
+def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
+    """Skips a test for given ROCm archs - multiprocess UTs"""
+
+    def decorator(func):
+        reason = None
+        if TEST_WITH_ROCM:
+            prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+            if prop in arch:
+                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
+
+        return unittest.skipIf(reason is not None, reason)(func)
+
+    return decorator
+
+
+def skip_if_rocm_ver_lessthan_multiprocess(version=None):
+    """Skips a test for ROCm based on ROCm ver - multiprocess UTs"""
+
+    def decorator(func):
+        reason = None
+        if TEST_WITH_ROCM:
+            rocm_version = str(torch.version.hip)
+            rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
+            rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+            if (
+                rocm_version_tuple is None
+                or version is None
+                or rocm_version_tuple < tuple(version)
+            ):
+                reason = f"skip_if_rocm_ver_lessthan_multiprocess: ROCm {rocm_version_tuple} is available but {version} required"
+
+        return unittest.skipIf(reason is not None, reason)(func)
+
+    return decorator
+
+
+def skip_if_win32():
+    return skip_but_pass_in_sandcastle_if(
+        sys.platform == "win32",
+        "This unit test case is not supported on Windows platform",
+    )
+
+
+def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
+    """
+    Returns True if the device's compute capability is (major, minor) or higher.
+    Error out if the device is not a CUDA device.
+    Returns False if device is a RoCM device.
+    Returns True if device is a non-CUDA device.
+    """
+    if device.type != "cuda":
+        return True
+
+    if torch.version.hip is not None:
+        # ROCm devices may have different compute capability codes
+        return False
+
+    return torch.cuda.get_device_capability(device) >= (major, minor)
+
+
+@retry_on_connect_failures
+def create_tcp_store(
+    addr="localhost",
+    world_size=1,
+    is_master=True,
+    timeout=timedelta(minutes=5),
+    wait_for_workers=True,
+    jit_class=False,
+    use_libuv=True,
+):
+    """
+    Creates a TCP store. Retries if the chosen port is already in use.
+    """
+    port = find_free_port()
+    if jit_class:
+        timeout_millisecond = int(timeout / timedelta(milliseconds=1))
+        return torch.classes.dist_c10d.TCPStore(
+            addr, port, world_size, is_master, timeout_millisecond
+        )
+    else:
+        return c10d.TCPStore(
+            addr,
+            port,
+            world_size,
+            is_master,
+            wait_for_workers=wait_for_workers,
+            use_libuv=use_libuv,
+        )
+
+
+# Default per-test timeout in seconds, used by ``get_timeout`` when a test has
+# no entry in TIMEOUT_OVERRIDE. Outside TSAN runs it can be overridden through
+# the DISTRIBUTED_TESTS_DEFAULT_TIMEOUT environment variable.
+if TEST_WITH_TSAN:
+    # TSAN runs much slower.
+    TIMEOUT_DEFAULT = 500
+else:
+    TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
+# Per-test timeouts, keyed by the bare test method name (the last
+# dot-separated component of the test id).
+TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}
+
+
+# https://github.com/pytorch/pytorch/issues/75665
+if TEST_WITH_ROCM:
+    TIMEOUT_OVERRIDE["test_join_kwargs"] = 200
+
+
+def create_device(interface=None, lazy_init: bool = False):
+    if sys.platform == "win32" or interface is None:
+        return c10d.ProcessGroupGloo.create_device(
+            hostname="127.0.0.1", lazy_init=lazy_init
+        )
+    else:
+        return c10d.ProcessGroupGloo.create_device(
+            interface=interface, lazy_init=lazy_init
+        )
+
+
+def get_timeout(test_id) -> int:
+    return TIMEOUT_OVERRIDE.get(test_id.split(".")[-1], TIMEOUT_DEFAULT)
+
+
+@contextmanager
+def captured_output():
+    new_out, new_err = StringIO(), StringIO()
+    old_out, old_err = sys.stdout, sys.stderr
+    try:
+        sys.stdout, sys.stderr = new_out, new_err
+        yield sys.stdout, sys.stderr
+    finally:
+        sys.stdout, sys.stderr = old_out, old_err
+
+
+def simple_sparse_reduce_tests(rank: int, world_size: int, num_inputs: int = 1):
+    """
+    Generate a number of basic test cases for sparse reduction.
+    These cover tensors with a varying number of sparse dimensions and a varying
+    number of dense dimensions. The only reduction operation we support is sum.
+    """
+
+    def generate(rank: int, world_size: int, sparse_dims: int = 1, dense_dims: int = 0):
+        # First sparse dimension is [0..rank].
+        # Subsequent dimensions are always 0, so we know there is
+        # a non-empty intersection between any two sparse tensors.
+        indices = torch.reshape(torch.arange(rank + 1), (1, rank + 1))
+        shape = [world_size] + [2 for _ in range(dense_dims)]
+        for _ in range(sparse_dims - 1):
+            indices = torch.cat((indices, torch.zeros(1, rank + 1)))
+            shape.append(world_size)
+        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
+        return torch.sparse_coo_tensor(indices, values, shape)
+
+    def compute_sum(fn, world_size: int):
+        return reduce(
+            operator.add, [fn(rank, world_size) for rank in range(world_size)]
+        )
+
+    return [
+        (
+            [
+                fn(num_inputs * rank + i, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+            [compute_sum(fn, num_inputs * world_size) for i in range(num_inputs)],
+        )
+        for fn in [
+            partial(generate, sparse_dims=1),
+            partial(generate, sparse_dims=2),
+            partial(generate, sparse_dims=3),
+            partial(generate, dense_dims=1),
+            partial(generate, dense_dims=2),
+            partial(generate, dense_dims=3),
+        ]
+    ]
+
+
+# HELPER FOR MULTIGPU TESTS
+def init_multigpu_helper(world_size: int, backend: str):
+    """Multigpu tests are designed to simulate the multi nodes with multi
+    GPUs on each node. Nccl backend requires equal #GPUs in each process.
+    On a single node, all visible GPUs are evenly
+    divided to subsets, each process only uses a subset.
+    """
+    nGPUs = torch.cuda.device_count()
+    if TEST_HPU:
+        nGPUs = torch.hpu.device_count()
+    if TEST_XPU:
+        nGPUs = torch.xpu.device_count()
+    visible_devices = range(nGPUs)
+
+    # If rank is less than or equal to number of available GPU's
+    # then each rank can be mapped to corresponding GPU.
+    nGPUs_per_process = 1
+    if world_size > nGPUs:
+        nGPUs_per_process = nGPUs // world_size
+    rank_to_GPU = {
+        i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
+        for i in range(world_size)
+    }
+    return rank_to_GPU
+
+
+tmp_dir: Optional[tempfile.TemporaryDirectory] = None
+
+
+def initialize_temp_directories(init_method: Optional[str] = None) -> None:
+    global tmp_dir
+    tmp_dir = tempfile.TemporaryDirectory()
+    os.environ["TEMP_DIR"] = tmp_dir.name
+    os.mkdir(os.path.join(tmp_dir.name, "barrier"))
+    os.mkdir(os.path.join(tmp_dir.name, "test_dir"))
+    init_dir_path = os.path.join(tmp_dir.name, "init_dir")
+    os.mkdir(init_dir_path)
+    # Set init method if specified.
+    if init_method is not None:
+        os.environ["INIT_METHOD"] = init_method
+    else:
+        os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join(
+            init_dir_path, "shared_init_file"
+        )
+
+
+def cleanup_temp_dir() -> None:
+    """Remove the module-level temporary directory, if one has been created."""
+    if tmp_dir is not None:
+        tmp_dir.cleanup()
+
+
+# Most tests operate with this worldsize
+DEFAULT_WORLD_SIZE = 4
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
+# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
+# example which inherits from this class. Its `setUp()` method calls into
+# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
+# subprocesses. During the spawn, the main process passes the test name to
+# subprocesses, and the name is acquired from self.id(). The subprocesses
+# then use the provided test function name to retrieve the function attribute
+# from the test instance and run it. The main process simply waits for all
+# subprocesses to join.
+
+
+class MultiProcessTestCase(TestCase):
+    MAIN_PROCESS_RANK = -1
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. There are certain tests that might use sys.exit() to
+    # simulate failures and in those cases, we can't have an exit code of 0,
+    # but we still want to ensure we didn't run into any other errors.
+    TEST_ERROR_EXIT_CODE = 10
+
+    # do not early terminate for distributed tests.
+    def _should_stop_test_suite(self) -> bool:
+        """Always return False so distributed tests are never terminated early."""
+        return False
+
+    # Many test cases init a process group but do not destroy it.  This property
+    # determines whether this base test class should call
+    # `destroy_process_group` on behalf of the test. Its value is customizable
+    # by derived TestCase's but it is a pan-TestCase value (cannot be customized
+    # for each test).
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        """Whether ``run_test`` calls ``destroy_process_group`` after the test body; True here."""
+        return True
+
+    @property
+    def world_size(self) -> int:
+        """Number of subprocesses started per test; ``DEFAULT_WORLD_SIZE`` here."""
+        return DEFAULT_WORLD_SIZE
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        """Create the test case and wrap its test method with ``join_or_run``.
+
+        Args:
+            method_name: Name of the test method to run.
+            methodName: unittest-style spelling of the same argument; when it
+                is not the default "runTest" it takes precedence.
+
+        Raises:
+            ValueError: If ``methodName`` was given explicitly and no
+                attribute with that name exists on the instance.
+        """
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            # Shadow the method on the instance with the role-aware wrapper.
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
+
+    def setUp(self) -> None:
+        """Reset per-test bookkeeping and create the temp file shared with subprocesses."""
+        super().setUp()
+
+        # Used for tests that are expected to return a non-0 exit code, such as
+        # SIGABRT thrown by watchdog.
+        self.special_return_code_checks: dict = {}
+
+        # Used for tests that may return any exit code, which makes it hard to
+        # check. This is rare, use with caution.
+        self.skip_return_code_checks: list = []
+
+        self.processes = []  # type: ignore[var-annotated]
+        self.rank = self.MAIN_PROCESS_RANK
+        # Only the path is kept: delete=False leaves the file on disk after
+        # the handle closes, and ``_start_processes`` passes the path to each
+        # subprocess.
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            self.file_name = f.name
+        # pid to pipe consisting of error message from process.
+        self.pid_to_pipe = {}  # type: ignore[var-annotated]
+
+    def tearDown(self) -> None:
+        """Terminate every tracked subprocess and drop the references to them."""
+        super().tearDown()
+        for p in self.processes:
+            p.terminate()
+        # Each Process instance holds a few open file descriptors. The unittest
+        # runner creates a new TestCase instance for each test method and keeps
+        # it alive until the end of the entire suite. We must thus reset the
+        # processes to prevent an effective file descriptor leak.
+        self.processes = []
+
+    def _current_test_name(self) -> str:
+        """Return the bare test method name: the last dot-separated part of ``self.id()``."""
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def _start_processes(self, proc) -> None:
+        """Start one subprocess per rank, each running ``_run`` for the current test.
+
+        Args:
+            proc: Process factory (e.g. a multiprocessing context's
+                ``Process`` class) called with ``target``, ``name``, ``args``
+                and ``kwargs``.
+        """
+        self.processes = []
+        for rank in range(int(self.world_size)):
+            # The parent end is kept for talking to the child; the child end
+            # is handed to the subprocess.
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            process = proc(
+                target=self.__class__._run,
+                name="process " + str(rank),
+                args=(
+                    rank,
+                    self._current_test_name(),
+                    self.file_name,
+                    child_conn,
+                ),
+                kwargs={
+                    "fake_pg": getattr(self, "fake_pg", False),
+                },
+            )
+            process.start()
+            logger.info("Started process %s with pid %s", rank, process.pid)
+            self.pid_to_pipe[process.pid] = parent_conn
+            self.processes.append(process)
+
+    def _spawn_processes(self) -> None:
+        """Start the per-rank subprocesses using the "spawn" start method."""
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            # NOTE(review): presumably raised when a start method was already
+            # set; the explicit "spawn" context below is used either way.
+            pass
+
+        proc = torch.multiprocessing.get_context("spawn").Process
+        self._start_processes(proc)
+
+    class Event(Enum):
+        """Requests the main process can send to a subprocess over its pipe."""
+
+        # Ask the subprocess to dump its traceback and send it back.
+        GET_TRACEBACK = 1
+
+    @staticmethod
+    def _event_listener(parent_pipe, signal_pipe, rank: int):
+        """Serve requests from the main process until told to stop.
+
+        Runs in a subprocess thread. Blocks until ``parent_pipe`` or
+        ``signal_pipe`` is ready. A ``GET_TRACEBACK`` event on ``parent_pipe``
+        is answered with the text written by ``faulthandler.dump_traceback``.
+        The loop ends when ``parent_pipe`` is closed or when anything arrives
+        on ``signal_pipe``.
+
+        Args:
+            parent_pipe: Connection to the main process.
+            signal_pipe: Receive end of the pipe ``run_test`` uses to request
+                shutdown of this listener.
+            rank: Rank of the owning subprocess, used only for logging.
+        """
+        logger.debug("Starting event listener thread for rank %s", rank)
+        while True:
+            ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
+
+            if parent_pipe in ready_pipes:
+                if parent_pipe.closed:
+                    logger.debug(
+                        "Pipe closed for process %s, stopping event listener thread",
+                        rank,
+                    )
+                    return
+
+                event = parent_pipe.recv()
+                logger.info("Received event %s on process %s", event, rank)
+
+                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
+                    # Return traceback to the parent process.
+                    with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
+                        faulthandler.dump_traceback(tmp_file)
+                        # Flush buffers and seek to read from the beginning
+                        tmp_file.flush()
+                        tmp_file.seek(0)
+                        parent_pipe.send(tmp_file.read())
+
+                        logger.info("Process %s sent traceback", rank)
+
+            # Any message on the signal pipe means the test is done.
+            if signal_pipe in ready_pipes:
+                return
+
+    @classmethod
+    def _run(
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+    ) -> None:
+        """Subprocess entry point: build the test case for ``rank`` and run it.
+
+        Args:
+            rank: Rank assigned to this subprocess.
+            test_name: Name of the test method to execute.
+            file_name: Path of the temp file created by the main process.
+            parent_pipe: Connection back to the main process.
+            **kwargs: Accepted but not used by this implementation.
+        """
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        """Execute ``test_name`` in this subprocess and report the outcome.
+
+        A daemon listener thread is started first so the main process can
+        request tracebacks. The test method is then invoked:
+
+        * ``unittest.SkipTest`` exits with the "generic" skip code.
+        * Any other ``Exception`` is logged, its traceback is sent through
+          ``parent_pipe``, and the process exits with ``TEST_ERROR_EXIT_CODE``.
+
+        In every case the listener thread is stopped and joined and
+        ``parent_pipe`` is closed. On normal completion the default process
+        group is destroyed when ``destroy_pg_upon_exit`` is true.
+
+        Args:
+            test_name: Name of the test method to run on this instance.
+            parent_pipe: Connection back to the main process.
+        """
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+        if sys.platform != "win32" and sys.platform != "darwin":
+            # Register signal handler to dump stack traces on FATALs.
+            # Windows and MacOS do not support the signal handlers.
+            torch._C._set_print_stack_traces_on_fatal_signal(True)
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+        common_utils.set_rng_seed()
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        try:
+            getattr(self, test_name)()
+        except unittest.SkipTest as se:
+            logger.info(  # noqa: G200
+                "Process %s skipping test %s for following reason: %s",
+                self.rank,
+                test_name,
+                str(se),
+            )
+            sys.exit(TEST_SKIPS["generic"].exit_code)
+        except Exception:
+            logger.error(
+                "Caught exception: \n%s exiting process %s with exit code: %s",
+                traceback.format_exc(),
+                self.rank,
+                MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
+            )
+            # Send error to parent process.
+            parent_pipe.send(traceback.format_exc())
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            # Wake the listener thread so it returns and can be joined.
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+        if self.destroy_pg_upon_exit:
+            try:
+                # Some tests do destroy the pgs, and destroy can't be called twice.
+                # This avoids spewing warnings about improperly shutting down.
+                c10d.destroy_process_group()
+            except (AssertionError, ValueError):
+                pass
+
+    def _get_timedout_process_traceback(self) -> None:
+        pipes = []
+        for i, process in enumerate(self.processes):
+            if process.exitcode is None:
+                pipe = self.pid_to_pipe[process.pid]
+                try:
+                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
+                    pipes.append((i, pipe))
+                except ConnectionError:
+                    logger.exception(
+                        "Encountered error while trying to get traceback for process %s",
+                        i,
+                    )
+
+        # Wait for results.
+        for rank, pipe in pipes:
+            try:
+                # Wait for traceback
+                if pipe.poll(5):
+                    if pipe.closed:
+                        logger.info(
+                            "Pipe closed for process %s, cannot retrieve traceback",
+                            rank,
+                        )
+                        continue
+
+                    traceback = pipe.recv()
+                    logger.error(
+                        "Process %s timed out with traceback: \n\n%s", rank, traceback
+                    )
+                else:
+                    logger.error(
+                        "Could not retrieve traceback for timed out process: %s", rank
+                    )
+            except ConnectionError:
+                logger.exception(
+                    "Encountered error while trying to get traceback for process %s",
+                    rank,
+                )
+
+    def _join_processes(self, fn) -> None:
+        """Wait in the main process for all subprocesses, then check exit codes.
+
+        Polls every 0.1 s until one of the following happens:
+
+        * a subprocess exits with ``TEST_ERROR_EXIT_CODE``, in which case all
+          active children are terminated;
+        * every subprocess has an exit code;
+        * the timeout from ``get_timeout`` elapses, in which case tracebacks
+          are requested and all subprocesses are terminated.
+
+        ``_check_return_codes`` is called afterwards, and all parent-side
+        pipes are closed regardless of the outcome.
+
+        Args:
+            fn: The original test method, forwarded to ``_check_return_codes``.
+        """
+        timeout = get_timeout(self.id())
+        start_time = time.time()
+        subprocess_error = False
+        try:
+            while True:
+                # check to see if any subprocess exited with an error early.
+                for i, p in enumerate(self.processes):
+                    # This is the exit code processes exit with if they
+                    # encountered an exception.
+                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
+                        print(
+                            f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes."
+                        )
+                        active_children = torch.multiprocessing.active_children()
+                        for ac in active_children:
+                            ac.terminate()
+                        subprocess_error = True
+                        break
+                if subprocess_error:
+                    break
+                # All processes have joined cleanly if they all have a valid exitcode
+                if all(p.exitcode is not None for p in self.processes):
+                    break
+                # Check if we should time out the test. If so, we terminate each process.
+                elapsed = time.time() - start_time
+                if elapsed > timeout:
+                    self._get_timedout_process_traceback()
+                    print(
+                        f"Timing out after {timeout} seconds and killing subprocesses."
+                    )
+                    for p in self.processes:
+                        p.terminate()
+                    break
+                # Sleep to avoid excessive busy polling.
+                time.sleep(0.1)
+
+            elapsed_time = time.time() - start_time
+            self._check_return_codes(fn, elapsed_time)
+        finally:
+            # Close all pipes
+            for pipe in self.pid_to_pipe.values():
+                pipe.close()
+
+    def _check_return_codes(self, fn, elapsed_time) -> None:
+        """
+        Checks that the return codes of all spawned processes match, and skips
+        tests if they returned a return code indicating a skipping condition.
+        """
+        # If no processes are spawned, there is nothing to check.
+        if not self.processes:
+            logger.warning(
+                "Note: no subprocesses were spawned, test was likely skipped."
+            )
+            return
+
+        first_process = self.processes[0]
+        # first, we check if there are errors in actual processes
+        # (via TEST_ERROR_EXIT CODE), and raise an exception for those.
+        # the reason we do this is to attempt to raise a more helpful error
+        # message than "Process x terminated/timed out"
+        # TODO: we should pipe the exception of the failed subprocess here.
+        # Currently, the actual exception is displayed as a logging output.
+        errored_processes = [
+            (i, p)
+            for i, p in enumerate(self.processes)
+            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+        ]
+        if errored_processes:
+            error = ""
+            for i, process in errored_processes:
+                # Get error from pipe.
+                error_message = self.pid_to_pipe[process.pid].recv()
+                error += (
+                    f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} "
+                    f"and exception:\n{error_message}\n"
+                )
+
+            raise RuntimeError(error)
+        # If no process exited uncleanly, we check for timeouts, and then ensure
+        # each process exited cleanly.
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    f"Process {i} terminated or timed out after {elapsed_time} seconds"
+                )
+
+        # Skip the test return code check
+        if fn in self.skip_return_code_checks:
+            return
+
+        for skip in TEST_SKIPS.values():
+            if first_process.exitcode == skip.exit_code:
+                if IS_SANDCASTLE:
+                    # Don't use unittest.skip to skip the test on sandcastle
+                    # since it creates tasks for skipped tests assuming there
+                    # is some follow-up needed. Instead just "pass" the test
+                    # with an appropriate message.
+                    logger.info(
+                        "Skipping %s on sandcastle for the following reason: %s",
+                        self.id(),
+                        skip.message,
+                    )
+                    return
+                else:
+                    raise unittest.SkipTest(skip.message)
+
+        # In most cases, we expect test to return exit code 0, standing for success.
+        expected_return_code = 0
+        # In some negative tests, we expect test to return non-zero exit code,
+        # such as watchdog throwing SIGABRT.
+        if fn in self.special_return_code_checks:
+            expected_return_code = self.special_return_code_checks[fn]
+
+        self.assertEqual(
+            first_process.exitcode,
+            expected_return_code,
+            msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
+        )
+
+    @property
+    def is_master(self) -> bool:
+        """Return True when this instance's ``rank`` equals 0."""
+        return self.rank == 0
+
+
+# Utility base class for distributed multi-process test cases.
+# It abstracts process group (PG) creation and deletion; the backend is selected
+# based on device type. The test functions can be instantiated per device type
+# using common_device_type.instantiate_device_type_tests.
+# Other backends can be supported by adding an entry in the backend() function.
+class DistributedTestBase(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        os.environ["WORLD_SIZE"] = str(self.world_size)
+        self._spawn_processes()
+
+    def tearDown(self):
+        try:
+            torch.distributed.destroy_process_group()
+        except AssertionError:
+            pass
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def backend(self, device) -> str:
+        if "cuda" in device:
+            return "nccl"
+        elif "hpu" in device:  # intel gaudi
+            return "hccl"
+        elif "xpu" in device:
+            return "xccl"
+        else:
+            return "gloo"
+
+    def create_pg(self, device, world_size=None):
+        if world_size is None:
+            world_size = self.world_size
+        num_visible_devices = torch.get_device_module(device).device_count()
+        store = torch.distributed.FileStore(self.file_name, num_visible_devices)
+        torch.distributed.init_process_group(
+            backend=self.backend(device),
+            world_size=world_size,
+            rank=self.rank,
+            store=store,
+        )
+        if "nccl" in self.backend(device) or "xccl" in self.backend(device):
+            torch.accelerator.set_device_index(self.rank)
+        return torch.distributed.distributed_c10d._get_default_group()
+
+    def rank_to_device(self, device):
+        num_visible_devices = torch.get_device_module(device).device_count()
+        return {i: [i % num_visible_devices] for i in range(self.world_size)}
+
+
+def run_subtests(
+    cls_inst,
+    subtest_config: dict[str, list[Any]],
+    test_fn: Callable,
+    *test_args,
+    **test_kwargs: Any,
+):
+    """
+    Runs a test function given by ``test_fn`` as a subtest according to the
+    configurations specified by ``subtest_config``. This amortizes the
+    costly setup overhead (including process spawn and initializing the
+    process group) over the subtests.
+
+    Args:
+        subtest_config (Dict[str, List[Any]]): A mapping from subtest
+            keyword argument name to a list of its possible values.
+        test_fn (Callable): A callable that runs the actual test.
+        test_args: Positional arguments to pass to ``test_fn``.
+        test_kwargs: Keyword arguments to pass to ``test_fn``.
+    """
+    # Convert the config mapping to a list to have a fixed order
+    subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
+    subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
+    subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
+    for values in itertools.product(*subtest_config_values):
+        # Map keyword to chosen value
+        subtest_kwargs = dict(zip(subtest_config_keys, values, strict=True))
+        with cls_inst.subTest(**subtest_kwargs):
+            torch._dynamo.reset()
+            test_fn(*test_args, **test_kwargs, **subtest_kwargs)
+            torch._dynamo.reset()
+        c10d.barrier()
+
+
+@functools.cache
+def has_efa() -> bool:
+    """
+    If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has
+    Libfabric EFA interfaces and EFA software components installed,
+    see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html.
+    """
+
+    try:
+        return (
+            subprocess.run(
+                ["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False
+            ).returncode
+            == 0
+        )
+    except FileNotFoundError:
+        pass
+    return False
+
+
+def tp_transports():
+    """
+    If the machine has Libfabric EFA interfaces and EFA software components installed it may cause
+    'RuntimeError: In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported' if tensorpipe
+    uses InfiniBand transport, so we exclude it from tensorpipe transports,
+    see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
+    """
+    return ["shm", "uv"] if has_efa() else None
+
+
+def spawn_threads_and_init_comms(
+    func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
+):
+    """
+    Wrapper to use with a test method
+    """
+    if func is None:
+        return partial(
+            spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
+        )
+
+    def _run_test_method_with_multi_threads(world_size, callback):
+        world = _install_threaded_pg()
+        global_store = c10d.HashStore()
+
+        def world_is_valid():
+            return world == c10d.distributed_c10d._world
+
+        def worker(rank, world_pg, store):
+            c10d.init_process_group(
+                backend="threaded", rank=rank, world_size=world_size, store=store
+            )
+            try:
+                callback()
+            except BaseException as ex:  # noqa: B036
+                # Exceptions are handled in MultiThreadedTestCase
+                MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
+                ProcessLocalGroup.exception_handle(
+                    ex
+                )  # trigger _terminate event and awaken worker threads
+            finally:
+                if world_is_valid():
+                    c10d.destroy_process_group()
+
+        threads = []
+        for rank in range(world_size):
+            t = threading.Thread(target=worker, args=(rank, world, global_store))
+            t.start()
+            threads.append(t)
+
+        return threads
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        # TODO: get test name from kwargs
+        torch._C._distributed_c10d._set_thread_isolation_mode(True)
+        try:
+            threads = _run_test_method_with_multi_threads(
+                world_size, lambda: func(self, *args, **kwargs)
+            )
+            # join and error handling
+            MultiThreadedTestCase._join_threads(threads, func)
+        finally:
+            torch._C._distributed_c10d._set_thread_isolation_mode(False)
+
+    return wrapper
+
+
+class MultiThreadedTestCase(TestCase):
+    """
+    Test runner that runs all tests with the in-proc process group using
+    multiple threads with the threaded process group.
+
+    Each test spawns world_size threads and run the test method in each thread.
+
+    Difference from regular MultiProcess test runner:
+    Must explicitly defines SetUp and call self._spawn_threads() to run the tests.
+    Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown)
+        to set up / tear down each thread when running each test.
+    No global state possible
+        How bad of a limitation is this?
+    """
+
+    exception_queue = queue.Queue()
+
+    MAIN_THREAD_RANK = -1
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_THREAD_RANK:
+                self._join_threads(self.threads, fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
+
+    def perThreadSetUp(self):
+        # super().setUp()  # TestCase.setUp() calls torch.manual_seed()
+        pass
+
+    def perThreadTearDown(self):
+        pass
+
+    def setUp(self) -> None:
+        """
+        setUp only set up things in the main thread, if you want to configure things
+        in the spawned threads, use perThreadSetUp
+        """
+        super().setUp()
+        self.rank = self.MAIN_THREAD_RANK
+        self.threads = []
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+
+    def tearDown(self):
+        """
+        tearDown only set up things in the main thread, if you want to configure things
+        in the spawned threads, use perThreadTearDown
+        """
+        super().tearDown()
+        self.threads = []
+
+    def _spawn_threads(self):
+        """
+        class method to spawn threads and run test, use this method in the SetUp of your TestCase
+        """
+        torch._C._distributed_c10d._set_thread_isolation_mode(True)
+        test_name = self._current_test_name
+        # for each test case, we need to create thread local world, and a global store
+        world = _install_threaded_pg()
+        self.__class__.global_store = c10d.HashStore()
+
+        def world_is_valid():
+            return world == c10d.distributed_c10d._world
+
+        if not world_is_valid():
+            raise RuntimeError("Invalid world")
+
+        for rank in range(self.world_size):
+            t = threading.Thread(
+                target=self.__class__._run, args=(test_name, rank, self.world_size)
+            )
+            t.start()
+            self.threads.append(t)
+
+    @classmethod
+    def _run(cls, test_name, rank, world_size, **kwargs):
+        self = cls(test_name)
+        self.rank = rank
+
+        # precision/rel_tol is a thread-local setting since it may be overridden per test, need to make
+        # every thread have the same value. This would be relevant when we use op db tests, where it
+        # needs those states to be set i.e. using instantiate_device_type_tests()
+        # TODO: figure out a better way to do this
+        if hasattr(self, "_tls"):
+            self._tls = threading.local()
+            self._tls.precision = TestCase._precision
+            self._tls.rel_tol = TestCase._rel_tol
+
+        self.run_test_with_threaded_pg(test_name, rank, world_size)
+
+    def run_test_with_threaded_pg(self, test_name, rank, world_size):
+        """
+        Run the current test associated with `test_name` using the threaded process group.
+        """
+        c10d.init_process_group(
+            backend="threaded",
+            rank=rank,
+            world_size=world_size,
+            store=self.__class__.global_store,
+        )
+        self.perThreadSetUp()
+
+        try:
+            getattr(self, test_name)()
+        except BaseException as ex:  # noqa: B036
+            self.exception_queue.put((rank, sys.exc_info()))
+            ProcessLocalGroup.exception_handle(
+                ex
+            )  # trigger _terminate event and awaken worker threads
+        finally:
+            c10d.destroy_process_group()
+            self.perThreadTearDown()
+
+    @classmethod
+    def _join_threads(cls, threads, fn):
+        timeout = TIMEOUT_DEFAULT
+        try:
+            for idx, thread in enumerate(threads):
+                thread.join(max(0, timeout))
+                if thread.is_alive():
+                    MultiThreadedTestCase.exception_queue.put(
+                        (
+                            idx,
+                            (
+                                TimeoutError,
+                                TimeoutError(
+                                    f"Rank failed to join in under {timeout} seconds"
+                                ),
+                                None,
+                            ),
+                        )
+                    )
+            ProcessLocalGroup.reset()
+            failed_ranks = []
+            while not cls.exception_queue.empty():
+                failure = cls.exception_queue.get()
+                failed_ranks.append(failure)
+        finally:
+            _uninstall_threaded_pg()
+            torch._C._distributed_c10d._set_thread_isolation_mode(False)
+
+        cls._check_return_codes(failed_ranks, timeout, fn)
+
+    @classmethod
+    def _check_return_codes(cls, failed_ranks, timeout, fn):
+        # Print based on exceptions raised from threads
+        #   SkipTest: print info for each thread
+        #   TimeoutError: raise RuntimeError for any timed out thread
+        #   Normal Exception: print error for each thread that raises exception
+        #   and raise a RuntimeError
+        error_msg = ""
+        skip_code = -1
+        for rank, exc_info in failed_ranks:
+            exc = exc_info[1]
+            if isinstance(exc, unittest.SkipTest):
+                logger.info(
+                    "Thread %s skipping test %s for following reason: %s",
+                    rank,
+                    fn,
+                    str(exc),
+                )
+                if skip_code < 0:
+                    skip_code = TEST_SKIPS["generic"].exit_code
+            elif isinstance(exc, TimeoutError):
+                msg = f"Thread {rank} terminated or timed out after {timeout} seconds\n"
+                logger.error(msg)
+                raise RuntimeError(msg)
+            elif isinstance(exc, Exception):
+                msg = "".join(traceback.format_exception(*exc_info))
+                logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
+                error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
+            elif isinstance(exc, SystemExit):
+                if type(exc.code) is int and skip_code < 0:
+                    skip_code = exc.code
+
+        # check exceptions
+        if len(error_msg) > 0:
+            raise RuntimeError(error_msg)
+        # check skip
+        if skip_code > 0:
+            for skip in TEST_SKIPS.values():
+                if skip_code == skip.exit_code:
+                    if IS_SANDCASTLE:
+                        # "pass" the test with an appropriate message.
+                        logger.info(
+                            "Skipping %s on sandcastle for the following reason: %s",
+                            fn,
+                            skip.message,
+                        )
+                        return
+                    else:
+                        raise unittest.SkipTest(skip.message)
+
+    @property
+    def world_size(self) -> int:
+        return DEFAULT_WORLD_SIZE
+
+    @property
+    def _current_test_name(self) -> str:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
+        """
+        The reason why we have this util function instead of
+        self.assertEqual is all threads are sharing one CPU RNG
+        so the assertion result is only reliable on rank 0
+        """
+        if self.rank == rank:
+            self.assertEqual(x, y, msg)
+
+    def assertNotEqualOnRank(self, x, y, msg=None, *, rank=0):
+        if self.rank == rank:
+            self.assertNotEqual(x, y)
+
+
+class SaveForwardInputsModule(nn.Module):
+    def __init__(
+        self,
+        forward_inputs: dict[nn.Module, torch.Tensor],
+        cast_forward_inputs: bool,
+    ) -> None:
+        super().__init__()
+        self.l = nn.Linear(100, 100)
+        self.forward_inputs = forward_inputs
+        self.cast_forward_inputs = cast_forward_inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.forward_inputs[self] = x
+        return self.l(x.to(self.l.weight.dtype) if self.cast_forward_inputs else x)
+
+
+class SaveForwardInputsModel(nn.Module):
+    def __init__(
+        self,
+        forward_inputs: dict[nn.Module, torch.Tensor],
+        cast_forward_inputs: bool,
+    ) -> None:
+        super().__init__()
+        self.c1 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
+        self.c2 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
+        self.forward_inputs = forward_inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.forward_inputs[self] = x
+        return self.c2(self.c1(x))
+
+
+@contextmanager
+def _dynamo_dist_per_rank_init(
+    rank, world_size, backend=None, init_pg=True, fake_pg=False
+):
+    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
+    # Just manually implement the most important part of the dynamo behavior to reset/clear.
+    if not fake_pg:
+        torch.accelerator.set_device_index(rank)
+
+    device_type = (
+        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+    )
+    if backend is None:
+        backend = c10d.get_default_backend_for_device(device_type)
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "6789"
+    if init_pg:
+        if fake_pg:
+            store = torch.testing._internal.distributed.fake_pg.FakeStore()
+            c10d.init_process_group(
+                backend="fake",
+                world_size=world_size,
+                rank=rank,
+                store=store,
+            )
+        else:
+            c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
+    torch._dynamo.reset()
+    torch._dynamo.utils.counters.clear()
+    try:
+        yield
+    finally:
+        torch._dynamo.reset()
+        torch._dynamo.utils.counters.clear()
+        if init_pg:
+            c10d.destroy_process_group()
+
+
+class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
+    """
+    Test harness for single-process dynamo distributed tests,
+    initializes dist process group.
+
+    Prefer this for simple tests, as it's easier to debug.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # _exit_stack is set up in TestCase
+        cls._exit_stack.enter_context(
+            patch.dict(
+                os.environ,
+                {
+                    "MASTER_ADDR": "localhost",
+                    "MASTER_PORT": "12355",
+                },
+            )
+        )
+        cls.rank = 0
+        device = torch.accelerator.current_accelerator().type
+        cls.device = f"{device}:{cls.rank}"
+        cls.device_ids = None if device in cls.device else [cls.rank]
+        c10d.init_process_group(
+            c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        c10d.destroy_process_group()
+        super().tearDownClass()
+
+
+class DynamoDistributedMultiProcTestCase(DistributedTestBase):
+    """
+    Use this for tests that actually run on multiple GPUs.
+
+    Decorate tests with @skip_if_lt_x_gpu(ngpu)
+
+    Note: MultiProcTestCase spawns processes per test and is slow.
+    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
+    sparingly for integration tests.
+    """
+
+    @property
+    def world_size(self) -> int:
+        return torch.accelerator.device_count()
+
+    @classmethod
+    def _run(
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+    ) -> None:
+        trace_log.addHandler(logging.NullHandler())
+
+        # The rest is copypasta from MultiProcessTestCase._run
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+
+class MultiProcContinuousTest(TestCase):
+    # Class variables:
+    # Rank value used by the main (dispatcher) process; setUp assigns it to self.rank.
+    MAIN_PROCESS_RANK = -1
+    # number of test processes; -2 means "unset", in which case setUpClass
+    # replaces it with the device count
+    world_size: int = -2  # unset state
+    # rank of the current process; set by _worker_loop in each worker
+    rank: int = -2  # unset state
+    # Rendezvous file passed to the FileStore in _init_pg
+    rdvz_file: Optional[str] = None
+    # timeout configured per class; passed to init_process_group in _init_pg
+    timeout: timedelta = timedelta(seconds=120)
+    # Poison pill for rest of tests if one of them fails
+    poison_pill: bool = False
+
+    @classmethod
+    def backend_str(cls) -> Optional[str]:
+        """
+        ProcessGroup backend str, passed as ``backend`` to ``init_process_group``.
+        To be customized by sub test classes, e.g. "nccl".
+        Otherwise we return None -- lazily decided by tensor.
+        """
+        return None
+
+    # Please override if you intend to test on specific device type
+    @classmethod
+    def device_type(cls) -> str:
+        curr_device = torch.accelerator.current_accelerator()
+        if curr_device is None:
+            return "cpu"
+        return curr_device.type
+
+    @classmethod
+    def opts(cls, high_priority_stream=False):
+        """
+        ProcessGroup init options, passed as ``pg_options`` to ``init_process_group``.
+        To be customized by sub test classes, e.g. ProcessGroupNCCLOpTest
+        Here we return None; ``high_priority_stream`` is not used by this
+        base implementation.
+        """
+        return None
+
+    @classmethod
+    def _init_pg(cls, rank, world_size, rdvz_file):
+        assert rdvz_file is not None
+        # rank should be local_rank for tests running on <= 8 gpus which is how all these tests are designed
+        # and we expect LOCAL_RANK set by torchrun. Setting it lets init_device_mesh set the device without
+        # issuing a warning
+        os.environ["LOCAL_RANK"] = str(rank)
+        store = c10d.FileStore(rdvz_file, world_size)
+        # create nccl processgroup with opts
+        c10d.init_process_group(
+            backend=cls.backend_str(),
+            world_size=world_size,
+            rank=rank,
+            store=store,
+            pg_options=cls.opts(),
+            timeout=cls.timeout,
+        )
+        cls.pg = c10d.distributed_c10d._get_default_group()
+
+    @classmethod
+    def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        test_name = test_id.rsplit(".", maxsplit=1)[-1]
+        # Get the test function from the test class
+        self = cls(test_name)
+        self.rank = cls.rank
+        self.world_size = cls.world_size
+        test_fn = getattr(self, test_name)
+
+        # Ensure all the ranks use the same seed.
+        common_utils.set_rng_seed()
+
+        # Run the test function
+        test_fn(**kwargs)
+
+    @classmethod
+    def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
+        """
+        Worker process body: initialize the process group, then run test ids
+        taken from ``task_queue`` until ``None`` is received.
+
+        For each test id one item is put on ``completion_queue``: the test id
+        on success, a ``unittest.SkipTest`` when the test exited with a code
+        listed in ``TEST_SKIPS``, or a ``RuntimeError`` carrying the formatted
+        traceback for any other exception.
+        """
+        # Remembers whether any test raised, to decide how to shut down below.
+        raised_exception = False
+        # Sub tests are going to access these values, check first
+        assert 0 <= rank < world_size
+        # set class variables for the test class
+        cls.rank = rank
+        cls.world_size = world_size
+
+        # Initialize the process group
+        cls._init_pg(rank, world_size, rdvz_file)
+
+        # End of bootstrap
+        logger.debug("Setup complete")
+
+        # Loop forever, waiting for a test name to run
+        while True:
+            test_id = task_queue.get()
+            logger.debug(f"Got test {test_id}")  # noqa: G004
+            # None means exit
+            if test_id is None:
+                break
+
+            # Run the test
+            try:
+                cls._run_test_given_id(test_id)
+                completion_queue.put(test_id)
+            except BaseException as ex:  # noqa: B036
+                if isinstance(ex, SystemExit):
+                    # Get exit code from the process
+                    exit_code = getattr(ex, "code", None)
+
+                    # Look up exit code in TEST_SKIPS to see if it is a valid skip
+                    skip_entry = next(
+                        (v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
+                        None,
+                    )
+
+                    # If we found an entry, we want to skip the test and send the object back to the main process
+                    if skip_entry:
+                        completion_queue.put(unittest.SkipTest(skip_entry.message))
+                        # Skip exception handling below, move to main thread for processing the skip
+                        continue
+
+                raised_exception = True
+                # Send the exception and stack trace back to the dispatcher
+                exc_info = sys.exc_info()
+                tb_str = "".join(traceback.format_exception(*exc_info))
+                # Create a new exception with the original exception and traceback
+                enhanced_ex = RuntimeError(f"Exception in worker process:\n{tb_str}")
+                enhanced_ex.__cause__ = ex
+                completion_queue.put(enhanced_ex)
+
+        # Termination
+        logger.debug("Terminating ...")
+        # Calling destroy_process_group when workers have exceptions
+        # while others are doing collectives will cause a deadlock since
+        # it waits for enqueued collectives to finish.
+        # Only call this on a clean exit path
+        if not raised_exception:
+            c10d.destroy_process_group()
+
+    @classmethod
+    def _spawn_processes(cls, world_size) -> None:
+        cls.processes = []
+        cls.task_queues = []
+        cls.completion_queues = []
+        # Need a rendezvous file for `init_process_group` purpose.
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            cls.rdvz_file = f.name
+
+        # CUDA multiprocessing requires spawn instead of fork, to make sure
+        # child processes have their own memory space.
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            # The start method has already been set
+            pass
+
+        for rank in range(int(world_size)):
+            task_queue = torch.multiprocessing.Queue()
+            completion_queue = torch.multiprocessing.Queue()
+            process = torch.multiprocessing.Process(
+                target=cls._worker_loop,
+                name="process " + str(rank),
+                daemon=True,  # so that child processes will exit if parent decides to terminate
+                args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
+            )
+            process.start()
+            cls.processes.append(process)
+            cls.task_queues.append(task_queue)
+            cls.completion_queues.append(completion_queue)
+            logger.debug("Started process %s with pid %s", rank, process.pid)  # noqa: UP031
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Class-scope test fixture. Run once for entire test class, before any test starts.
+        Resolve the world size and spawn the worker processes.
+        """
+        super().setUpClass()
+
+        # Use device count as world size
+        device_type = cls.device_type()
+        # If world_size is still the -2 sentinel (i.e. not set), use device count
+        if cls.world_size == -2:
+            cls.world_size = torch.get_device_module(device_type).device_count()
+            if cls.world_size == 0:
+                raise unittest.SkipTest(f"No {device_type} devices available")
+
+        logger.info(
+            f"Testing class {cls.__name__} on {cls.world_size} {device_type}"  # noqa: G004
+        )
+
+        cls._spawn_processes(cls.world_size)
+
+    @classmethod
+    def tearDownClass(cls):
+        """
+        Class-scope test fixture. Run once for entire test class, after all tests finish.
+        Signal the workers to exit, join them, and remove the rendezvous file.
+        """
+        logger.debug(f"Joining {cls.world_size} workers")  # noqa: G004
+        # Enqueue "None" to all workers to tell them to exit
+        for task_queue in cls.task_queues:
+            task_queue.put(None)
+
+        # Wait for all workers to exit
+        for process in cls.processes:
+            process.join()
+
+        # Clean up the rendezvous file (best effort: an OSError is ignored)
+        try:
+            os.remove(cls.rdvz_file)
+        except OSError:
+            pass
+
+        logger.info(f"Class {cls.__name__} finished")  # noqa: G004
+        super().tearDownClass()
+
+    def setUp(self) -> None:
+        """
+        Test fixture. Run before each test.
+        """
+        super().setUp()
+
+        # I am the dispatcher
+        self.rank = self.MAIN_PROCESS_RANK
+
+        # If this test class hits an exception in one test, skip the rest of tests
+        if self.__class__.poison_pill:
+            raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}")
+
+        # Enqueue "current test" to all workers
+        for i, task_queue in enumerate(self.task_queues):
+            logger.debug(f"Sending Rank {i}: {self.id()}")  # noqa: G004
+            task_queue.put(self.id())
+
+    def _worker_run_main_wait(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
+                # Wait for the workers to finish the test
+                for i, completion_queue in enumerate(self.completion_queues):
+                    rv = completion_queue.get()
+                    if isinstance(rv, unittest.SkipTest):
+                        raise rv
+                    if isinstance(rv, BaseException):
+                        # Hit an exception, re-raise it in the main process.
+                        logger.warning(
+                            f"Detected failure from Rank {i} in: {self.id()}, "  # noqa: G004
+                            f"skipping rest of tests in Test class: {self.__class__.__name__}"  # noqa: G004
+                        )
+                        # Poison rest of tests (because ProcessGroup may be not
+                        # reusable now)
+                        self.__class__.poison_pill = True
+                        raise rv
+
+                    # Success
+                    assert rv == self.id()
+                    logger.debug(
+                        f"Main proc detected rank {i} finished {self.id()}"  # noqa: G004
+                    )
+            else:
+                # Worker just runs the test
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self._worker_run_main_wait(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..474bb689f0ad9bcd7ee171b68de22f7752b37e3c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py
@@ -0,0 +1,227 @@
+# mypy: ignore-errors
+
+
+import torch
+
+
+# Functions and classes for describing the dtypes a function supports
+# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
+
+
+# Verifies each given dtype is a torch.dtype
+def _validate_dtypes(*dtypes):
+    for dtype in dtypes:
+        assert isinstance(dtype, torch.dtype)
+    return dtypes
+
+
+# class for tuples corresponding to a PyTorch dispatch macro
+class _dispatch_dtypes(tuple):
+    __slots__ = ()
+
+    def __add__(self, other):
+        assert isinstance(other, tuple)
+        return _dispatch_dtypes(tuple.__add__(self, other))
+
+
+_empty_types = _dispatch_dtypes(())
+
+
+def empty_types():
+    return _empty_types
+
+
+_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
+
+
+def floating_types():
+    return _floating_types
+
+
+_floating_types_and_half = _floating_types + (torch.half,)
+
+
+def floating_types_and_half():
+    return _floating_types_and_half
+
+
+def floating_types_and(*dtypes):
+    return _floating_types + _validate_dtypes(*dtypes)
+
+
+_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
+
+
+def floating_and_complex_types():
+    return _floating_and_complex_types
+
+
+def floating_and_complex_types_and(*dtypes):
+    return _floating_and_complex_types + _validate_dtypes(*dtypes)
+
+
+_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
+
+
+def double_types():
+    return _double_types
+
+
+# NB: Does not contain uint16/uint32/uint64 for BC reasons
+_integral_types = _dispatch_dtypes(
+    (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
+)
+
+
+def integral_types():
+    return _integral_types
+
+
+def integral_types_and(*dtypes):
+    return _integral_types + _validate_dtypes(*dtypes)
+
+
+_all_types = _floating_types + _integral_types
+
+
+def all_types():
+    return _all_types
+
+
+def all_types_and(*dtypes):
+    return _all_types + _validate_dtypes(*dtypes)
+
+
+_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
+
+
+def complex_types():
+    return _complex_types
+
+
+def complex_types_and(*dtypes):
+    return _complex_types + _validate_dtypes(*dtypes)
+
+
+_all_types_and_complex = _all_types + _complex_types
+
+
+def all_types_and_complex():
+    return _all_types_and_complex
+
+
+def all_types_and_complex_and(*dtypes):
+    return _all_types_and_complex + _validate_dtypes(*dtypes)
+
+
+_all_types_and_half = _all_types + (torch.half,)
+
+
+def all_types_and_half():
+    return _all_types_and_half
+
+
+_all_mps_types = (  # tuple literal (not a set) keeps the dtype order deterministic
+    _dispatch_dtypes((torch.float, torch.half, torch.bfloat16)) + _integral_types
+)
+
+
+def all_mps_types():
+    return _all_mps_types
+
+
+def all_mps_types_and(*dtypes):
+    return _all_mps_types + _validate_dtypes(*dtypes)
+
+
+_float8_types = _dispatch_dtypes(
+    (
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+    )
+)
+
+
+def float8_types():
+    return _float8_types
+
+
+def float8_types_and(*dtypes):
+    return _float8_types + _validate_dtypes(*dtypes)
+
+
+def all_types_complex_float8_and(*dtypes):
+    return _all_types + _complex_types + _float8_types + _validate_dtypes(*dtypes)
+
+
+def custom_types(*dtypes):
+    """Create a ``_dispatch_dtypes`` tuple from arbitrary dtypes"""
+    return _empty_types + _validate_dtypes(*dtypes)
+
+
+# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
+
+
+# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
+def get_all_dtypes(
+    include_half=True,
+    include_bfloat16=True,
+    include_bool=True,
+    include_complex=True,
+    include_complex32=False,
+    include_qint=False,
+) -> list[torch.dtype]:
+    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
+        include_half=include_half, include_bfloat16=include_bfloat16
+    )
+    if include_bool:
+        dtypes.append(torch.bool)
+    if include_complex:
+        dtypes += get_all_complex_dtypes(include_complex32)
+    if include_qint:
+        dtypes += get_all_qint_dtypes()
+    return dtypes
+
+
+def get_all_math_dtypes(device) -> list[torch.dtype]:
+    return (  # str() lets ``device`` be a torch.device as well as a plain string
+        get_all_int_dtypes()
+        + get_all_fp_dtypes(
+            include_half=str(device).startswith("cuda"), include_bfloat16=False
+        )
+        + get_all_complex_dtypes()
+    )
+
+
+def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]:
+    return (
+        [torch.complex32, torch.complex64, torch.complex128]
+        if include_complex32
+        else [torch.complex64, torch.complex128]
+    )
+
+
+def get_all_int_dtypes() -> list[torch.dtype]:
+    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+
+
+def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> list[torch.dtype]:
+    dtypes = [torch.float32, torch.float64]
+    if include_half:
+        dtypes.append(torch.float16)
+    if include_bfloat16:
+        dtypes.append(torch.bfloat16)
+    return dtypes
+
+
+def get_all_qint_dtypes() -> list[torch.dtype]:
+    return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
+
+
+float_to_corresponding_complex_type_map = {
+    torch.float16: torch.complex32,
+    torch.float32: torch.complex64,
+    torch.float64: torch.complex128,
+}
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..74b3cdc78f2d93086cc82886ddf36f5c9cc40184
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py
@@ -0,0 +1,1623 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["oncall: distributed"]
+
+import contextlib
+import os
+import re
+import sys
+import time
+import unittest
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from contextlib import nullcontext
+from copy import deepcopy
+from enum import auto, Enum
+from functools import wraps
+from typing import Any, cast, no_type_check, Optional, Union
+from unittest import mock
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed._composable import checkpoint
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import (
+    CPUOffload,
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+)
+from torch.distributed.fsdp._common_utils import TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+    FSDPParamGroup,
+    RegisterPostBackwardFunction,
+)
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
+from torch.distributed.tensor import distribute_tensor, DTensor, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    MultiThreadedTestCase,
+    run_subtests,
+    TEST_SKIPS,
+)
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    get_cycles_per_ms,
+    set_rng_seed,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_XPU,
+)
+from torch.utils._triton import has_triton
+
+
+DEVICE_COUNT = 4  # default
+
+if TEST_CUDA:
+    DEVICE_TYPE = "cuda"
+    DISTRIBUTED_BACKEND = "nccl"
+    DEVICE_COUNT = torch.cuda.device_count()
+elif TEST_HPU:
+    DEVICE_TYPE = "hpu:0"
+    DISTRIBUTED_BACKEND = "hccl"
+elif TEST_XPU:
+    DEVICE_TYPE = "xpu"
+    DISTRIBUTED_BACKEND = "xccl"
+    DEVICE_COUNT = torch.xpu.device_count()
+else:
+    DEVICE_TYPE = "cpu"
+    DISTRIBUTED_BACKEND = "gloo"
+    DEVICE_COUNT = 1
+
+
+class FSDPInitMode(Enum):
+    # No FSDP wrapping
+    NO_FSDP = auto()
+    # FSDP recursive wrapping
+    RECURSIVE = auto()
+    # TODO: FSDP non-recursive wrapping
+    # NONRECURSIVE = auto()
+
+
+class DEVICEInitMode(Enum):
+    # Move model to DEVICE before passing to the FSDP constructor
+    DEVICE_BEFORE = auto()
+    # Move model to DEVICE after passing to the FSDP constructor
+    DEVICE_AFTER = auto()
+    # Keep on CPU
+    DEVICE_NEVER = auto()
+
+
+class FSDPTestModel(nn.Module, ABC):
+    """This defines the interface expected from all models used commonly for
+    FSDP unit tests."""
+
+    @abstractmethod
+    def get_input(self, device) -> tuple[torch.Tensor, ...]:
+        """Returns an input for the model as a tuple."""
+        ...
+
+    @abstractmethod
+    def get_loss(self, input, output) -> torch.Tensor:
+        """Returns the loss given the input and output."""
+        ...
+
+    @abstractmethod
+    def run_backward(self, loss) -> None:
+        """Runs the backward pass (e.g. including ``loss.backward()``)."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def init(*args: Any, **kwargs: Any) -> nn.Module:
+        """Initializes an instance of this model."""
+        ...
+
+
+def _assert_module_states(
+    model: nn.Module,
+    process_group: dist.ProcessGroup,
+    assert_fn: Callable,
+):
+    """
+    All-gathers module states across ranks and calls ``assert_fn`` on each pair
+    of corresponding states from rank 0 and a nonzero rank. For example, if
+    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
+    states are equal across ranks.
+    """
+    # Include names for debugging convenience
+    named_module_states = [
+        (param_name, param.detach().cpu())
+        for param_name, param in model.named_parameters()
+    ]
+    named_module_states += [
+        (buffer_name, buffer.detach().cpu())
+        for buffer_name, buffer in model.named_buffers()
+    ]
+    world_size = dist.get_world_size(process_group)
+    olist = [None for _ in range(world_size)]
+    dist.all_gather_object(olist, named_module_states, group=process_group)
+    rank0_states = olist[0]
+    assert rank0_states is not None  # mypy
+    for state in olist[1:]:
+        assert state is not None  # mypy
+        for (_, p1), (_, p2) in zip(rank0_states, state, strict=True):
+            assert_fn(p1, p2)
+
+
+def get_devtype():
+    return torch.device(DEVICE_TYPE)
+
+
+def _zero_model(
+    model: nn.Module,
+    zero_buffers: bool = False,
+    summon_full=True,
+):
+    """Zeros the parameters and optionally buffers of ``model`` in place."""
+    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
+    with ctx:
+        for param in model.parameters():
+            with torch.no_grad():
+                param.zero_()
+        if zero_buffers:
+            for buffer in model.buffers():
+                with torch.no_grad():
+                    buffer.zero_()
+
+
+def _get_state_dict(model, cpu_offload=False, half=False):
+    if not cpu_offload:
+        model = model.to(DEVICE_TYPE)
+    if half:
+        model.half()
+
+    return model.state_dict()
+
+
+def subtest_name(test_name_mapping, *args):
+    return "_".join(
+        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
+    )
+
+
+def _broadcast_state_dict(rank, state_dict):
+    # For non-FSDP roots, some parts of the model state on rank 0 may
+    # not be on CPU, so we move everything to CPU to avoid issues like:
+    # https://github.com/pytorch/pytorch/issues/77113.
+    for param_name, param in state_dict.items():
+        if param.device != torch.device("cpu"):
+            state_dict[param_name] = param.cpu()
+
+    olist = [state_dict if rank == 0 else None]
+    dist.broadcast_object_list(olist)
+    state_dict = cast(dict[str, torch.Tensor], olist[0])
+    # Ensure that the state is on DEVICE
+    for param_name in state_dict:
+        state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)
+    return state_dict
+
+
+def get_full_params(model: nn.Module, recurse: bool = True):
+    """
+    Returns the full unsharded parameters of ``model``. Any FSDP-managed
+    parameters offloaded to CPU are moved to GPU in the returned list.
+
+    Args:
+        recurse (bool): If ``False``, only unshards the parameters immediate to
+            ``model``; if ``True``, recurses through the module hierarchy
+            rooted at ``model``.
+    """
+    with FSDP.summon_full_params(model, recurse=recurse):
+        return deepcopy(list(model.parameters()))
+
+
+def _move_to_device(model: nn.Module, move_to_device: bool):
+    return model.to(DEVICE_TYPE) if move_to_device else model
+
+
+def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
+    return model if not wrap_fsdp else FSDP(model, *args, **kwargs)
+
+
+class DummyProcessGroup:
+    def __init__(self, rank: int, size: int):
+        self._rank = rank
+        self._size = size
+
+    def rank(self) -> int:
+        return self._rank
+
+    def size(self) -> int:
+        return self._size
+
+    def allreduce(self, *args, **kwargs):
+        dist_wait = mock.Mock()
+
+        def get_future():
+            future: torch.futures.Future = torch.futures.Future()
+            future.set_result(1)
+            return future
+
+        dist_wait.get_future = get_future
+        return dist_wait
+
+
+class TransformerWithSharedParams(FSDPTestModel):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        device_init_mode: DEVICEInitMode,
+        add_bn: bool,
+        deterministic: bool,
+    ):
+        super().__init__()
+        self.rank = group.rank()
+        self.world_size = group.size()
+        if deterministic:
+            torch.manual_seed(0)
+        d_vocab = 23
+        d_model = 16
+
+        self.embed_tokens = nn.Embedding(d_vocab, d_model)
+        self.transformer = nn.Transformer(
+            d_model=d_model,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=8,
+            dropout=0.1,
+        )
+        self.output_proj = nn.Linear(d_model, d_vocab)
+
+        # share the embedding and output projection weights
+        self.output_proj.weight = self.embed_tokens.weight
+        self.register_buffer(
+            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
+        )
+        self.register_buffer(
+            "long_buffer",
+            torch.zeros_like(self.vocab_bias, dtype=torch.long),  # type: ignore[arg-type]
+        )  # type: ignore[arg-type]
+
+        self.bs = 2
+        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
+        if device_init_mode == DEVICEInitMode.DEVICE_BEFORE:
+            self = self.to(DEVICE_TYPE)
+        if deterministic:
+            self.eval()
+
+    def get_input(self, device):
+        torch.manual_seed(1 + self.rank)  # keep everything deterministic
+        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
+        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
+        return (src, tgt)
+
+    def forward(self, src_ids, tgt_ids):
+        src = self.embed_tokens(src_ids)
+        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
+        tgt = self.embed_tokens(tgt_ids)
+        tgt = self.bn(tgt)
+        x = self.transformer(src, tgt)
+        return self.output_proj(x)
+
+    def get_loss(self, input, output):
+        _, tgt = input
+        return nn.functional.cross_entropy(
+            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
+        )
+
+    def run_backward(self, loss):
+        loss.backward()
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        add_bn: bool = True,
+    ) -> Union[nn.Module, FSDP]:
+        """
+        Initializes a :class:`TransformerWithSharedParams` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps with
+                top-level FSDP. By default, the top-level FSDP uses the
+                ``ModuleWrapPolicy`` for encoder and decoder layers, but a
+                different auto wrap policy may be specified via
+                ``fsdp_kwargs``.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+            add_bn (bool): Whether to include batch norm in the model.
+        """
+
+        # Shallow-copy so the ``pop`` below never mutates the caller's dict.
+        fsdp_kwargs = dict(fsdp_kwargs) if fsdp_kwargs is not None else {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            if isinstance(group, tuple):
+                pg = group[0]
+            else:
+                pg = group
+            return TransformerWithSharedParams(
+                pg, device_init_mode, add_bn, deterministic
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Default to the `ModuleWrapPolicy`
+            if "auto_wrap_policy" not in fsdp_kwargs:
+                auto_wrap_policy = ModuleWrapPolicy(
+                    {
+                        TransformerEncoderLayer,
+                        TransformerDecoderLayer,
+                    }
+                )
+            else:
+                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
+
+            if (
+                "sharding_strategy" in fsdp_kwargs
+                and fsdp_kwargs["sharding_strategy"]
+                in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
+                and not isinstance(group, tuple)
+            ):
+                fsdp_pg = None
+            else:
+                fsdp_pg = group
+
+            if isinstance(group, tuple):
+                tformer_pg = group[0]
+            else:
+                tformer_pg = group
+
+            m = TransformerWithSharedParams(
+                tformer_pg, device_init_mode, add_bn, deterministic
+            )
+            fsdp_model = FSDP(
+                m,
+                fsdp_pg,
+                auto_wrap_policy=auto_wrap_policy,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+    def get_ignored_modules(self):
+        return [self.transformer]
+
+
+class NestedWrappedModule(FSDPTestModel):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super().__init__()
+        self.rank = group.rank()
+        self.world_size = group.size()
+        move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+
+        def _maybe_wrap(layer):
+            if wrap_fsdp:
+                return FSDP(layer, group, **fsdp_kwargs)
+            return layer
+
+        if deterministic:
+            torch.manual_seed(0)
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(8, 4), move_to_device),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)),
+                    _move_to_device(nn.Linear(16, 16), move_to_device),
+                ),
+            ),
+            _maybe_wrap(_move_to_device(nn.Linear(16, 4), move_to_device)),
+            _move_to_device(nn.Linear(4, 8), move_to_device),
+        )
+
+    def get_input(self, device):
+        torch.manual_seed(1 + self.rank)  # keep everything deterministic
+        return (torch.rand(4, 8, device=device),)
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = output.sum()
+        return loss
+
+    def run_backward(self, loss):
+        loss.backward()
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ) -> nn.Module:
+        """
+        Initializes a :class:`NestedWrappedModule` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
+                modules with FSDP but not the top-level module. The model may
+                later be wrapped with a top-level FSDP external to this method
+                if desired.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+        """
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return NestedWrappedModule(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Does not wrap with top-level FSDP
+            fsdp_model = NestedWrappedModule(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ):
+        """
+        Initializes a :class:`NestedWrappedModule` instance, but unlike
+        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
+        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
+        policy.
+        """
+        model = super(
+            AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
+        ).init(
+            group=group,
+            fsdp_init_mode=FSDPInitMode.NO_FSDP,
+            device_init_mode=device_init_mode,
+            fsdp_kwargs=fsdp_kwargs,
+            deterministic=deterministic,
+        )
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return model
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            fsdp_kwargs = fsdp_kwargs or {}
+            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+
+
+class NonUniformReqGradNWM(NestedWrappedModule):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super(NestedWrappedModule, self).__init__()
+        # This `__init__` only differs from `NestedWrappedModule.__init__` in that
+        # the last two `nn.Linear` layers are FSDP wrapped in a `nn.Sequential`
+        # container. This arrangement results in all elements of the last two parameters
+        # residing on a single rank. Freezing all parameters except those two allows us
+        # to verify that `ShardedGradScaler` accommodates situations where some ranks
+        # have no (non-zero sized) parameter shards.
+        self.rank = group.rank()
+        self.world_size = group.size()
+        move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+
+        def _maybe_wrap(layer):
+            if wrap_fsdp:
+                return FSDP(layer, group, **fsdp_kwargs)
+            return layer
+
+        if deterministic:
+            torch.manual_seed(0)
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(8, 4), move_to_device),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)),
+                    _move_to_device(nn.Linear(16, 16), move_to_device),
+                ),
+            ),
+            _maybe_wrap(
+                nn.Sequential(
+                    _move_to_device(nn.Linear(16, 4), move_to_device),
+                    _move_to_device(nn.Linear(4, 8), move_to_device),
+                ),
+            ),
+        )
+
+    @staticmethod
+    def _set_nonuniform_req_grad(model, req_grad_mask) -> None:
+        for n, p in model.named_parameters():
+            if not re.match(req_grad_mask, n):
+                p.requires_grad_(False)
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ):
+        """
+        Initializes a :class:`NestedWrappedModule` instance, but unlike
+        :meth:`NestedWrappedModule.init`, it wraps a second :class:`torch.nn.Sequential`
+        container to enable the desired non-uniform ``requires_grad``
+        ``use_orig_params=True`` tests. For both ``RECURSIVE`` and ``NO_FSDP``
+        init modes, freezes all parameters except the last two to validate
+        ``ShardedGradScaler`` support for ranks with no (non-zero sized) local shards in
+        FSDP ``use_orig_params=True`` mode.
+        """
+        # The parameters that should remain unfrozen are in `module.2.1`. The regex
+        # pattern below matches the relevant parameter names both with and without
+        # an interstitial FSDP module indicator (`_fsdp_wrapped_module`) present.
+        req_grad_pattern = re.compile(r"module\.2.*\.1.*")
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            ddp_model = NonUniformReqGradNWM(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+            )
+            NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern)
+            return ddp_model
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            if fsdp_kwargs is None:
+                fsdp_kwargs = {}
+            fsdp_model = NonUniformReqGradNWM(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class ModuleWithDelay(FSDPTestModel):
+    """This class wraps a :class:`FSDPTestModel` to optionally add a delay
+    after computing the loss and/or before the gradient reduction."""
+
+    def __init__(
+        self,
+        module: nn.Module,
+        delay_after_loss_ms: int,
+        delay_before_reduction_ms: int,
+    ):
+        super().__init__()
+        self.delay_after_loss_ms = delay_after_loss_ms
+        self.delay_before_reduction_ms = delay_before_reduction_ms
+        self.module = module
+
+    def get_input(self, device):
+        return self.module.get_input(device)  # type: ignore[operator]
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = self.module.get_loss(input, output)  # type: ignore[operator]
+        if self.delay_after_loss_ms > 0:
+            if TEST_HPU or TEST_XPU:
+                time.sleep(self.delay_after_loss_ms / 1000)
+            elif TEST_CUDA:
+                torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
+
+        return loss
+
+    def run_backward(self, loss):
+        orig_reduce_scatter = torch.distributed.reduce_scatter_tensor
+
+        def _delayed_reduce_scatter(*args, **kwargs):
+            if self.delay_before_reduction_ms > 0:
+                if TEST_CUDA:
+                    torch.cuda._sleep(
+                        int(self.delay_before_reduction_ms * get_cycles_per_ms())
+                    )
+                elif TEST_HPU or TEST_XPU:
+                    time.sleep(self.delay_before_reduction_ms / 1000)
+            return orig_reduce_scatter(*args, **kwargs)
+
+        with mock.patch(
+            "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
+        ):
+            self.module.run_backward(loss)  # type: ignore[operator]
+
+    @staticmethod
+    def init(
+        module_class: type[FSDPTestModel],
+        *model_args: Any,
+        delay_after_loss_ms: int,
+        delay_before_reduction_ms: int,
+        **model_kwargs: Any,
+    ):
+        """
+        Args:
+            module_class (Type[FSDPTestModel]): Wrapped module class to which
+                to add delays.
+            model_args: Positional arguments forwarded to the ``module_class``
+                ``init()``.
+            delay_after_loss_ms (int): Delay after computing the loss/before
+                the optimizer step (in ms).
+            delay_before_reduction_ms (int): Delay before reduce-scattering
+                gradients (in ms).
+            model_kwargs: Keyword arguments forwarded to the ``module_class``
+                ``init()``.
+        """
+        return ModuleWithDelay(
+            module_class.init(*model_args, **model_kwargs),
+            delay_after_loss_ms,
+            delay_before_reduction_ms,
+        )
+
+
+class NestedWrappedModuleWithDelay(ModuleWithDelay):
+    @staticmethod
+    def init(  # type: ignore[override]
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode = DEVICEInitMode.DEVICE_AFTER,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        delay_after_loss_ms: int = 0,
+        delay_before_reduction_ms: int = 0,
+    ):
+        return ModuleWithDelay.init(
+            NestedWrappedModule,
+            group=group,
+            fsdp_init_mode=fsdp_init_mode,
+            device_init_mode=device_init_mode,
+            fsdp_kwargs=fsdp_kwargs,
+            deterministic=deterministic,
+            delay_after_loss_ms=delay_after_loss_ms,
+            delay_before_reduction_ms=delay_before_reduction_ms,
+        )
+
+
+class DummyDDP(nn.Module):
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+
+class MixtureOfExperts(NestedWrappedModule):
+    """Test model whose ``module`` is a four-layer ``nn.Sequential`` that
+    contains one "shared" linear layer and one "expert" linear layer.
+
+    The expert layer's parameters are tagged with an ``expert = True``
+    attribute so that :meth:`run_backward` can skip them when it reduces
+    gradients manually.
+    """
+
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        delay_before_free_ms: int,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        # NOTE(review): `self.rank` and `self.world_size` used below are
+        # presumably set by the parent constructor -- confirm against
+        # `NestedWrappedModule.__init__`.
+        super().__init__(
+            group=group,
+            wrap_fsdp=wrap_fsdp,
+            device_init_mode=device_init_mode,
+            deterministic=deterministic,
+        )
+        self.group = group
+        self.delay_before_free_ms = delay_before_free_ms
+        self.wrap_fsdp = wrap_fsdp
+        self.move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+        if deterministic:
+            # Give each rank different expert parameters
+            torch.manual_seed(42 + self.rank)
+        d_expert = 23
+        d_shared = 12
+        d_input = 8
+        expert = _move_to_device(nn.Linear(d_expert, d_shared), self.move_to_device)
+
+        # Record the expert parameter count and tag each expert parameter so
+        # it can be recognised later via `hasattr(p, "expert")`.
+        self.num_expert_params = sum(p.numel() for p in expert.parameters())
+        for p in expert.parameters():
+            p.expert = True  # type: ignore[attr-defined]
+
+        if deterministic:
+            # Keep all other parameters the same across ranks
+            torch.manual_seed(0)
+
+        shared = _move_to_device(nn.Linear(d_shared, d_expert), self.move_to_device)
+
+        if wrap_fsdp:
+            # we create a process group of size 1 for the expert params
+            expert_group = torch.distributed.new_group(
+                [group.rank()]
+            )  # world size 1 means no shard
+            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
+            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]
+
+        # Layer order: input projection -> shared -> expert -> output
+        # projection; `forward` relies on the expert being at index 2.
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(d_input, d_shared), self.move_to_device),
+            shared,
+            expert,
+            _move_to_device(nn.Linear(d_shared, d_input), self.move_to_device),
+        )
+
+    def forward(self, x):
+        """Run ``self.module`` on ``x``.
+
+        When ``delay_before_free_ms`` is positive and the expert layer is an
+        FSDP instance, the forward pass runs with
+        ``torch.distributed.fsdp._runtime_utils._reshard`` patched by a
+        version that sleeps for that many milliseconds before resharding.
+        """
+        if self.delay_before_free_ms > 0:
+            expert = self.module[2]
+            if isinstance(expert, FSDP):
+                orig_reshard = torch.distributed.fsdp._runtime_utils._reshard
+
+                def _delayed_reshard(*args, **kwargs):
+                    if TEST_CUDA:
+                        torch.cuda._sleep(
+                            int(self.delay_before_free_ms * get_cycles_per_ms())
+                        )
+                    elif TEST_HPU or TEST_XPU:
+                        time.sleep(self.delay_before_free_ms / 1000)
+
+                    return orig_reshard(*args, **kwargs)
+
+                # This patch covers any `import torch..._reshard` uses.
+                with mock.patch(
+                    "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
+                ):
+                    return self.module(x)
+
+        return self.module(x)
+
+    def run_backward(self, loss):
+        """Call ``loss.backward()`` and, when the model is not FSDP-wrapped,
+        divide the non-expert gradients by ``world_size`` and all-reduce them
+        over ``self.group``."""
+        loss.backward()
+        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
+        if not self.wrap_fsdp:
+            with torch.no_grad():
+                for p in self.parameters():
+                    if hasattr(p, "expert"):
+                        continue  # these params don't need grad reduction
+                    if p.grad is not None:
+                        p.grad.div_(self.world_size)
+                        torch.distributed.all_reduce(p.grad, group=self.group)
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        delay_before_free_ms: int = 0,
+    ):
+        """
+        Initializes a :class:`MixtureOfExperts` instance.
+
+        Args:
+            group (dist.ProcessGroup): Process group passed to the
+                :class:`MixtureOfExperts` constructor.
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
+                modules with FSDP, including the expert and shared layers, but
+                not the top-level module. The model may later be wrapped with a
+                top-level FSDP external to this method if desired.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+            delay_before_free_ms (int): Delay before resharding expert
+                parameters in the forward pass (in ms).
+
+        Raises:
+            ValueError: If ``fsdp_init_mode`` is neither ``NO_FSDP`` nor
+                ``RECURSIVE``.
+        """
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return MixtureOfExperts(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                delay_before_free_ms=delay_before_free_ms,
+                deterministic=deterministic,
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Does not wrap with top-level FSDP
+            fsdp_model = MixtureOfExperts(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                delay_before_free_ms=delay_before_free_ms,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        device: Optional[torch.device] = None,
+        *,
+        bias: bool = True,
+        with_buffer: bool = False,
+        dim_multiplier: int = 4,
+    ):
+        super().__init__()
+        self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias)
+        self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias)
+        if with_buffer:
+            self.register_buffer("buffer", torch.randn((dim,), device=device))
+        else:
+            self.buffer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        z = self.in_proj(x)
+        z = F.relu(z)
+        z = self.out_proj(z)
+        z = F.relu(z)
+        if self.buffer is not None:
+            z = z + self.buffer
+        return z
+
+    def reset_parameters(self):
+        if self.buffer is not None:
+            torch.nn.init.normal_(self.buffer)
+
+
+class MLPStack(nn.Sequential):
+    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
+        modules: list[nn.Module] = [
+            # Use multiplier of 3 to exercise uneven case
+            MLP(mlp_dim, dim_multiplier=3),
+            MLP(mlp_dim),
+            MLP(mlp_dim, dim_multiplier=3),
+        ]
+        if with_seq_parallel:
+            modules.append(nn.LayerNorm(mlp_dim, bias=False))
+        super().__init__(*modules)
+        self.with_seq_parallel = with_seq_parallel
+
+    def parallelize(
+        self,
+        tp_mesh: DeviceMesh,
+        dp_mesh: DeviceMesh,
+        use_activation_checkpointing: bool,
+        **fsdp_kwargs,
+    ) -> "MLPStack":
+        parallelize_plan = {
+            # Pass `use_local_output=False` to keep as DTensor to preserve
+            # uneven activation dims
+            "0.in_proj": ColwiseParallel(use_local_output=False),
+            "0.out_proj": RowwiseParallel(use_local_output=False),
+            "1.in_proj": ColwiseParallel(use_local_output=False),
+            "1.out_proj": RowwiseParallel(use_local_output=False),
+            "2.in_proj": ColwiseParallel(use_local_output=False),
+            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
+            if self.with_seq_parallel
+            else RowwiseParallel(),
+        }
+        if self.with_seq_parallel:
+            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
+        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
+        for module in self:
+            if isinstance(module, nn.LayerNorm):
+                continue
+            if use_activation_checkpointing:
+                checkpoint(module)
+            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
+        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
+        return self
+
+
+class DoubleLinear(nn.Module):
+    """
+    This can be used for returning multiple outputs from a module
+    (``use_second_linear=True``) or for having an unused module (``False``).
+    """
+
+    def __init__(self, dim: int, use_second_linear: bool = True):
+        super().__init__()
+        self.lin1 = nn.Linear(dim, dim)
+        self.lin2 = nn.Linear(dim, dim)
+        self.relu = nn.ReLU()
+        self.use_second_linear = use_second_linear
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        if self.use_second_linear:
+            return self.relu(self.lin1(x)), self.relu(self.lin2(x))
+        return self.relu(self.lin1(x))
+
+
+# NOTE: For these patch methods, if we want safety under multi-threading (e.g.
+# when using multi-threaded process group), then we want:
+# (1) a barrier immediately after reading the original value to ensure that all
+# threads see the same original value
+# (2) a barrier immediately before restoring the original value to ensure that
+# all threads use the patched value inside the context
+@contextlib.contextmanager
+def patch_all_gather(new_all_gather_into_tensor: Callable):
+    orig_all_gather = dist.all_gather_into_tensor
+    dist.barrier()
+    dist.all_gather_into_tensor = new_all_gather_into_tensor
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_gather_into_tensor = orig_all_gather
+
+
+@contextlib.contextmanager
+def patch_foreach_all_gather(new_foreach_all_gather: Callable):
+    orig_foreach_all_gather = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+        new_foreach_all_gather
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+            orig_foreach_all_gather
+        )
+
+
+@contextlib.contextmanager
+def patch_foreach_reduce(new_foreach_reduce: Callable):
+    orig_foreach_foreach_reduce = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+        new_foreach_reduce
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+            orig_foreach_foreach_reduce
+        )
+
+
+@contextlib.contextmanager
+def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
+    orig_reduce_scatter = dist.reduce_scatter_tensor
+    dist.barrier()
+    dist.reduce_scatter_tensor = new_reduce_scatter_tensor
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.reduce_scatter_tensor = orig_reduce_scatter
+
+
+@contextlib.contextmanager
+def patch_all_reduce(new_all_reduce: Callable):
+    orig_all_reduce = dist.all_reduce
+    dist.barrier()
+    dist.all_reduce = new_all_reduce
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_reduce = orig_all_reduce
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_unshard(new_unshard: Callable):
+    orig_unshard = FSDPParamGroup.unshard
+    dist.barrier()
+    FSDPParamGroup.unshard = new_unshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.unshard = orig_unshard
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_reshard(new_reshard: Callable):
+    orig_reshard = FSDPParamGroup.reshard
+    dist.barrier()
+    FSDPParamGroup.reshard = new_reshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.reshard = orig_reshard
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_post_backward(new_post_backward: Callable):
+    orig_post_backward = FSDPParamGroup.post_backward
+    dist.barrier()
+    FSDPParamGroup.post_backward = new_post_backward
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.post_backward = orig_post_backward
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_register_post_backward_hook_backward(new_backward: Callable):
+    orig_backward = RegisterPostBackwardFunction.backward
+    dist.barrier()
+    RegisterPostBackwardFunction.backward = new_backward
+    try:
+        yield
+    finally:
+        dist.barrier()
+        RegisterPostBackwardFunction.backward = orig_backward
+
+
+def reduce_scatter_with_assert(
+    cls,
+    orig_reduce_scatter: Callable,
+    assert_fn: Callable,  # `assert_fn(output: Tensor)`
+    *args: Any,
+    **kwargs: Any,
+):
+    if len(args) > 0:
+        output = args[0]
+    elif "output" in kwargs:
+        output = kwargs["output"]
+    else:
+        raise AssertionError(
+            f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}"
+        )
+    assert_fn(output)
+    return orig_reduce_scatter(*args, **kwargs)
+
+
+def check_sharded_parity(
+    cls,  # unit test class
+    replicated_module: nn.Module,
+    sharded_module: nn.Module,
+    prefixes_to_ignore: tuple[str, ...] = (),
+):
+    for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip(
+        replicated_module.named_parameters(),
+        sharded_module.named_parameters(),
+        strict=True,
+    ):
+        clean_sharded_name = sharded_name
+        for prefix in prefixes_to_ignore:
+            clean_sharded_name = clean_sharded_name.replace(prefix, "")
+        cls.assertEqual(replicated_name, clean_sharded_name)
+        cls.assertIsInstance(sharded_param, DTensor)
+        assert isinstance(sharded_param, DTensor)  # mypy
+        mesh, placements = sharded_param.device_mesh, sharded_param.placements
+        if tuple(placements) == (Shard(0), Shard(0)):
+            raise AssertionError(
+                "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), "
+                "so we cannot check for equality using it"
+            )
+        sharded_ref_param = distribute_tensor(replicated_param, mesh, placements)
+        cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local())
+        if replicated_param.grad is None:
+            cls.assertIsNone(sharded_param.grad)
+            continue
+        cls.assertIsNotNone(sharded_param.grad)
+        sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements)
+        cls.assertIsInstance(sharded_param.grad, DTensor)
+        assert isinstance(sharded_param.grad, DTensor)  # mypy
+        cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())
+
+
+@unittest.skipIf(TEST_XPU, "not-support-multithread")
+class FSDPTestMultiThread(MultiThreadedTestCase):
+    """Multi-threaded FSDP test base class; skipped when ``TEST_XPU`` is set.
+
+    The world size equals ``DEVICE_COUNT`` and ``torch._dynamo`` is reset in
+    the per-thread set-up and tear-down hooks.
+    """
+
+    @property
+    def world_size(self):
+        # One rank per available device.
+        return DEVICE_COUNT
+
+    def setUp(self):
+        # Run the base-class set-up first, then spawn the worker threads.
+        super().setUp()
+        self._spawn_threads()
+
+    def run_subtests(self, *args, **kwargs):
+        # Delegates to the module-level `run_subtests` helper, passing this
+        # test case as its first argument.
+        return run_subtests(self, *args, **kwargs)
+
+    def perThreadSetUp(self):
+        torch._dynamo.reset()
+
+    def perThreadTearDown(self):
+        torch._dynamo.reset()
+
+
+class FSDPTest(MultiProcessTestCase):
+    def setUp(self):
+        """Run the base-class set-up, set ``TORCH_NCCL_DESYNC_DEBUG=0`` in the
+        environment and spawn the test processes."""
+        super().setUp()
+        # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
+        # which can cause unit test flakiness:
+        # https://github.com/pytorch/pytorch/issues/90848
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    @property
+    def world_size(self):
+        """Number of ranks used by the test; equals ``DEVICE_COUNT``."""
+        return DEVICE_COUNT
+
+    @property
+    def process_group(self):
+        """The default process group, as returned by
+        ``dist.distributed_c10d._get_default_group()``."""
+        return dist.distributed_c10d._get_default_group()
+
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        """Always ``False``: the process group is not auto-destroyed on exit."""
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
+    @property
+    def init_method(self):
+        """Init-method string built from ``FILE_SCHEMA`` followed by
+        ``self.file_name``."""
+        return f"{FILE_SCHEMA}{self.file_name}"
+
+    def _check_cpu_offload(self, fsdp_model, cpu_offload):
+        """Assert that ``fsdp_model.cpu_offload`` equals ``cpu_offload``."""
+        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
+
+    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
+        """Assert that ``fsdp_model.backward_prefetch`` equals
+        ``backward_prefetch``."""
+        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)
+
+    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
+        """Assert that ``fsdp_model.forward_prefetch`` equals
+        ``forward_prefetch``."""
+        self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)
+
+    def run_subtests(self, *args, **kwargs):
+        """Delegate to the module-level ``run_subtests`` helper, passing this
+        test case as its first argument."""
+        return run_subtests(self, *args, **kwargs)
+
+    @classmethod
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):  # type: ignore[override]
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        fake_pg = kwargs.get("fake_pg", False)
+
+        print(f"dist init r={self.rank}, world={self.world_size}")
+        if torch.accelerator.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        # Specify gloo backend to make 'init_process_group()' succeed,
+        # Actual tests will be skipped if there is no enough GPUs.
+        try:
+            if fake_pg:
+                store = torch.testing._internal.distributed.fake_pg.FakeStore()
+                dist.init_process_group(
+                    backend="fake",
+                    world_size=self.world_size,
+                    rank=rank,
+                    store=store,
+                )
+            else:
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=DISTRIBUTED_BACKEND,
+                    world_size=int(self.world_size),
+                    rank=self.rank,
+                )
+        except RuntimeError as e:
+            if "recompile" in e.args[0]:
+                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
+
+            raise
+
+        device_ids = None
+        device_id = self.rank % DEVICE_COUNT
+        if TEST_CUDA or TEST_XPU:
+            torch.accelerator.set_device_index(device_id)
+        device_ids = [device_id]
+
+        # Execute barrier prior to running test to ensure that every process
+        # has finished initialization and that the following test
+        # immediately exiting due to a skip doesn't cause flakiness.
+        dist.barrier(device_ids=device_ids)
+
+        torch._dynamo.reset()
+        set_rng_seed()
+        self.run_test(test_name, pipe)
+        torch._dynamo.reset()
+
+        dist.barrier(device_ids=device_ids)
+
+        dist.destroy_process_group()
+
+    def _train_for_several_steps(
+        self,
+        model: nn.Module,
+        num_steps: int,
+        autocast: bool,
+        lr: float = 0.01,
+        fsdp_cpu_offload: Optional[CPUOffload] = None,
+        save_model: bool = False,
+        mixed_precision: Optional[MixedPrecision] = None,
+        enable_sharded_grad_scaler: bool = False,
+        use_pure_fp16: bool = False,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        """Train ``model`` for ``num_steps`` SGD steps and return the detached
+        (scaled) loss of the last step.
+
+        Args:
+            model: Wrapper whose ``.module`` provides ``get_input``,
+                ``get_loss`` and ``run_backward``.
+            num_steps: Number of optimisation steps to run.
+            autocast: Whether ``torch.amp.autocast`` is enabled for the
+                forward pass and loss computation.
+            lr: Learning rate of the SGD optimizer (momentum 0.9).
+            fsdp_cpu_offload: When its ``offload_params`` is set and ``model``
+                is an FSDP instance, parameters are asserted to be on CPU
+                after forward (unless the sharding strategy is in
+                ``NO_RESHARD_AFTER_FORWARD_STRATEGIES``) and after backward.
+            save_model: If true, each step clones the state dict, zeroes the
+                model and loads the state dict back.
+            mixed_precision: Used to decide the input dtype and the expected
+                loss dtype.
+            enable_sharded_grad_scaler: ``enabled`` flag of the
+                ``ShardedGradScaler``.
+            use_pure_fp16: If true, inputs are converted to half precision
+                and the loss is expected to be ``torch.float16``.
+            sharded_grad_scaler_kwargs: Extra keyword arguments for the
+                ``ShardedGradScaler`` constructor.
+        """
+        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
+
+        model_device = next(model.parameters()).device
+        if sharded_grad_scaler_kwargs is None:
+            sharded_grad_scaler_kwargs = {}
+        sharded_grad_scaler = ShardedGradScaler(
+            enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs
+        )
+        # use SGD with momentum instead of Adam, since Adam is scale invariant
+        # and this makes it bad for tests
+        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+        for _ in range(num_steps):
+            optim.zero_grad()
+            with torch.amp.autocast(DEVICE_TYPE, enabled=autocast):
+                # Inputs are always created on `DEVICE_TYPE`, regardless of
+                # CPU offloading or the model's device
+                input = model.module.get_input(torch.device(DEVICE_TYPE))  # type: ignore[operator, union-attr]
+                if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
+                    if isinstance(input, torch.Tensor):
+                        input = input.half()
+                    else:
+                        input = tuple(x.half() for x in input)
+                output = model(*input)
+                # Post-forward, if CPU offloading model param should be on CPU.
+                if (
+                    cpu_offload_params
+                    and isinstance(model, FSDP)
+                    # If not resharding after forward, the parameters are still
+                    # exposed as unsharded views into the GPU flat parameter
+                    and model.sharding_strategy
+                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+                ):
+                    for p in model.parameters():
+                        # Params should always be on CPU
+                        self.assertEqual(p.device, torch.device("cpu"))
+
+                loss = model.module.get_loss(input, output).to(model_device)  # type: ignore[operator, union-attr]
+            loss = sharded_grad_scaler.scale(loss)
+
+            # Check the loss dtype expected for the precision configuration.
+            if not mixed_precision and not use_pure_fp16:
+                assert loss.dtype == torch.float32, (
+                    "loss data type should be float32, as the original \
+                    parameter data type is float32."
+                )
+            else:
+                if use_pure_fp16:
+                    self.assertEqual(loss.dtype, torch.float16)
+                # FSDP loss is fp16, DDP AMP loss is fp32
+                elif isinstance(model, FSDP):
+                    assert mixed_precision is not None  # mypy
+                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
+                else:
+                    self.assertEqual(loss.dtype, torch.float32)
+            model.module.run_backward(loss)  # type: ignore[operator, union-attr]
+            # Post-backward, if CPU offloading model params should be on CPU.
+            if cpu_offload_params and isinstance(model, FSDP):
+                for p in model.parameters():
+                    # Params should always be on CPU
+                    self.assertEqual(p.device, torch.device("cpu"))
+            # Unscale the gradients and step
+            sharded_grad_scaler.step(optim)
+            # Update the scale factor
+            sharded_grad_scaler.update()
+            # if save_model, simulate save + load.
+            if save_model:
+                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
+                # Zero params, if save/load state_dict did not work properly, this
+                # would break the parity test with DDP.
+                _zero_model(model)
+                model.load_state_dict(state_dict)
+
+        if isinstance(model, FSDP):
+            model._assert_state(TrainingState.IDLE)
+        # NOTE(review): `loss` is unbound here when `num_steps` is 0 --
+        # callers are presumably expected to pass at least one step.
+        return loss.detach()  # type: ignore[possibly-undefined]
+
+    def _test_fsdp_parity(
+        self,
+        model_class: type[FSDPTestModel],
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        ref_init_fn: Optional[Callable] = None,
+        num_iters: int = 2,
+        save_model: bool = True,
+        cpu_offload: CPUOffload = CPUOffload(),
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        sharding_strategy: Optional[ShardingStrategy] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        forward_prefetch: bool = False,
+        use_orig_params: bool = False,
+        enable_sharded_grad_scaler: bool = False,
+        use_pure_fp16: bool = False,
+        init_kwargs: Optional[dict[str, Any]] = None,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
+        **fsdp_kwargs,
+    ):
+        """
+        Tests FSDP training against a reference, which defaults to DDP but
+        may be customized with ``ref_init_fn``.
+
+        Args:
+            model_class (Type[FSDPTestModel]): A model class that inherits from
+                ``FSDPTestModel``, which defines the expected interface.
+            fsdp_init_mode (FSDPInitMode): The mode to initialize the
+                FSDP-wrapped model. This should not be ``NO_FSDP``.
+            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
+                non-wrapped model to construct the reference model, where this
+                wrapper should provide data parallel semantics. If ``None``,
+                then the callable defaults to the DDP constructor.
+        """
+        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
+            "Expects an FSDP init mode that wraps with FSDP"
+        )
+        if init_kwargs is None:
+            init_kwargs = {}
+        lr = 1e-2
+        rank = self.process_group.rank()
+        # Establish reference behavior with DDP
+        model = model_class.init(
+            self.process_group,
+            FSDPInitMode.NO_FSDP,
+            DEVICEInitMode.DEVICE_BEFORE,
+            deterministic=True,
+            **init_kwargs,
+        )
+        if ref_init_fn is None:
+            if TEST_HPU:
+                ref_model = DDP(
+                    model, device_ids=[DEVICE_TYPE], output_device=DEVICE_TYPE
+                )
+            else:
+                ref_model = DDP(model, device_ids=[rank], output_device=rank)
+        else:
+            ref_model = ref_init_fn(model)
+        if use_pure_fp16:
+            ref_model = ref_model.half()
+        ref_loss = self._train_for_several_steps(
+            ref_model,
+            num_iters,
+            autocast=mixed_precision is not None,
+            lr=lr,
+            fsdp_cpu_offload=cpu_offload,
+            mixed_precision=mixed_precision,
+            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
+            use_pure_fp16=use_pure_fp16,
+            sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
+        )
+        ddp_params = list(ref_model.parameters())
+        # Check against FSDP behavior
+        fsdp_kwargs.update(
+            {
+                "cpu_offload": cpu_offload,
+                "backward_prefetch": backward_prefetch,
+                "sharding_strategy": sharding_strategy,
+                "mixed_precision": mixed_precision,
+                "forward_prefetch": forward_prefetch,
+                "use_orig_params": use_orig_params,
+            }
+        )
+        try:
+            fsdp_model = model_class.init(
+                self.process_group,
+                fsdp_init_mode,
+                device_init_mode,
+                fsdp_kwargs,
+                deterministic=True,
+                **init_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
+        if not isinstance(fsdp_model, FSDP):
+            # Enforce that we wrap with top-level FSDP since we are comparing
+            # assuming a data parallel reference and some test models may not
+            # do so in their `init()` method
+            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
+        if use_pure_fp16:
+            # Change the model parameter dtype after FSDP initialization
+            fsdp_model = fsdp_model.half()
+        if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+            fsdp_model = fsdp_model.to(DEVICE_TYPE)
+        offload_params = cpu_offload is not None and cpu_offload.offload_params
+        # Offloading parameters with `DEVICE_AFTER` should raise an error during
+        # lazy initialization due to the parameter devices not being CPU;
+        # otherwise, all parameter devices should be CPU
+        expects_device_error = (
+            offload_params and device_init_mode == DEVICEInitMode.DEVICE_AFTER
+        )
+        expects_cpu_device = (
+            offload_params and device_init_mode != DEVICEInitMode.DEVICE_AFTER
+        )
+        if expects_cpu_device:
+            cpu_device = torch.device("cpu")
+            for param in fsdp_model.parameters():
+                self.assertEqual(param.device, cpu_device)
+        context = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "An FSDP-managed module with parameter CPU offloading enabled "
+                f"has parameters on {DEVICE_TYPE}",
+            )
+            if expects_device_error
+            else nullcontext()
+        )
+        with context:
+            fsdp_loss = self._train_for_several_steps(
+                fsdp_model,
+                num_iters,
+                autocast=False,
+                lr=lr,
+                fsdp_cpu_offload=cpu_offload,
+                save_model=save_model,
+                mixed_precision=mixed_precision,
+                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
+                use_pure_fp16=use_pure_fp16,
+                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
+            )
+        # No need to check for parameter and loss parity if expecting an error
+        if expects_device_error:
+            return
+        # Check parameter devices are CPU if offloading to CPU before calling
+        # `get_full_params()`, which will cast the parameters to FP32
+        if offload_params:
+            cpu_device = torch.device("cpu")
+            for param in fsdp_model.parameters():
+                self.assertEqual(param.device, cpu_device)
+            fsdp_loss = fsdp_loss.to(DEVICE_TYPE)
+        fsdp_unsharded_params = get_full_params(fsdp_model)
+        # Do not check dtype since the reference DDP loss may not be the same
+        # dtype as the FSDP loss in the case of mixed precision
+        torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
+        # Do not check for parameter parity if using mixed precision since (1)
+        # the DDP parameters are in FP16 (from `half()`) while the FSDP
+        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
+        # the optimizer in FP16 while FSDP runs it in FP32
+        # TODO: Disable checking the parameters for pure FP16 due to floating
+        # point inaccuracy. Note that this means that the backward pass is not
+        # checked: https://github.com/pytorch/pytorch/issues/90784
+        if mixed_precision is None and not use_pure_fp16:
+            self.assertEqual(
+                ddp_params,
+                fsdp_unsharded_params,
+                exact_device=True,
+                msg="FSDP did not match DDP",
+            )
+
+
+def compiled_fsdp_test(compile_compute_on_module: Optional[type] = None):  # decorator factory
+    def fully_shard_with_compiled_compute(*args, **kwargs):
+        torch.distributed.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
+        if compile_compute_on_module is None or isinstance(
+            args[0], compile_compute_on_module
+        ):
+            args[0].compile()  # call .compile() on the module just passed to fully_shard
+
+    class FullyShardMode(Enum):
+        EAGER = auto()
+        COMPILED_COMPUTE = auto()
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):  # runs func once per FullyShardMode; func's return value is discarded
+            original_fully_shard: Any = torch.distributed.fsdp.fully_shard
+            for mode in FullyShardMode:
+                if mode != FullyShardMode.EAGER and not has_triton():
+                    warnings.warn(
+                        "Inductor on GPU needs Triton and recent GPU arch", stacklevel=2
+                    )
+                    continue
+                # Snapshot config, then barrier so every thread reads the values before any thread patches them
+                original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
+                original_compile_threads = torch._inductor.config.compile_threads
+                torch.distributed.barrier()
+
+                if mode == FullyShardMode.EAGER:
+                    fully_shard_patch = original_fully_shard
+                elif mode == FullyShardMode.COMPILED_COMPUTE:
+                    torch._dynamo.config.skip_fsdp_hooks = True
+                    torch._inductor.config.compile_threads = 1
+                    fully_shard_patch = fully_shard_with_compiled_compute  # type: ignore[assignment]
+                else:
+                    raise NotImplementedError(
+                        f"Need to implement FullyShardMode={mode}"
+                    )
+
+                # fully_shard is imported as a global in the test's module
+                # through `from ... import fully_shard`, so patch func's globals
+                func.__globals__[original_fully_shard.__name__] = fully_shard_patch
+                func(*args, **kwargs)  # NOTE(review): if this raises, the restores below are skipped (no try/finally)
+                # Barrier: other threads may still be using the patched func; wait before restoring
+                torch.distributed.barrier()
+                func.__globals__[original_fully_shard.__name__] = original_fully_shard
+                torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
+                torch._inductor.config.compile_threads = original_compile_threads
+
+        return wrapper
+
+    return decorator
+
+
+class SkipModule(nn.Module):  # minimal module: a single bias-free linear layer
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin = nn.Linear(10, 10, bias=False)  # 10 in-features -> 10 out-features, no bias
+
+    def forward(self, x):
+        return self.lin(x)
+
+
+class NestedLinear(nn.Module):  # one bias-free linear layer, optionally passed through `wrap`
+    def __init__(self, fsdp_wrap):
+        super().__init__()
+        if fsdp_wrap:  # `wrap` is defined/imported outside this view; presumably FSDP's wrap helper — confirm
+            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).to(DEVICE_TYPE))
+        else:
+            self.nested_linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)
+
+    def forward(self, x):
+        return self.nested_linear(x)
+
+
+class SkipModel(nn.Module):  # linear -> SkipModule -> wrapped NestedLinear, applied in that order
+    def __init__(self, double_nest):
+        super().__init__()
+        self.linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)
+        self.linear_skip = SkipModule().to(DEVICE_TYPE)
+        self.nested_linear = wrap(  # double_nest controls whether the inner linear is itself wrapped
+            NestedLinear(fsdp_wrap=double_nest), device_id=DEVICE_TYPE
+        )
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.linear_skip(x)
+        x = self.nested_linear(x)
+        return x
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac6e851d7e28b0466f9b49862f1df78781c2a461
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py
@@ -0,0 +1,323 @@
+# mypy: ignore-errors
+
+# Torch
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+import torch.jit.quantized
+
+# Testing utils
+from torch.testing._internal.common_dtype import floating_and_complex_types_and
+from torch.testing._internal.common_utils import TestCase, \
+    freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
+from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401
+
+# Standard library
+from itertools import chain
+from typing import Union
+from torch._C import TensorType
+
+import io
+
+def check_output_types(self, func, ref_outputs, args, kwargs):  # args/kwargs are accepted but unused
+    graph = getattr(func, 'last_graph', None)  # NOTE(review): if absent, graph is None and the next line raises AttributeError
+    types = [o.type() for o in graph.outputs()]
+    self.assertTrue(len(types) == 1)  # exactly one graph output is expected
+    t = types[0]
+    torch._C._jit_assert_is_instance(ref_outputs, t)  # presumably asserts ref_outputs conforms to type t — confirm
+
+# Test method names ('test_nn_' + name) for which check_against_reference skips the grad-grad check
+nn_functional_single_grad = frozenset('test_nn_' + name for name in [
+    'pdist',
+    'multilabel_margin_loss',
+    'max_unpool3d',
+    'multi_margin_loss',
+    'binary_cross_entropy',
+    'binary_cross_entropy_size_average',
+    'ctc_loss',
+    'grid_sample',
+])
+
+def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
+                            allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
+    """Verifies a function performs identically to some reference implementation.
+
+    Commonly, this is used to verify that a JIT implementation
+    (func) matches the behavior of the eager implementation
+    (reference_func). ``output_func`` post-processes each result before gradients are taken.
+    """
+    kwargs = kwargs if kwargs else {}
+
+    def allSum(vs):  # scalar: sum_i (i + 1) * sum(v_i) (abs of the sum for complex), skipping None / non-float entries
+        if isinstance(vs, torch.Tensor):
+            vs = (vs,)
+        return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
+                   for i, v in enumerate(vs)
+                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
+
+    def clone_tensor(t, preserve_requires_grad):  # detached copy; requires_grad kept only when asked for
+        require_grad = preserve_requires_grad and t.requires_grad
+        return t.detach().clone().requires_grad_(require_grad)
+
+    def clone_inputs(preserve_requires_grad: bool):  # clones tensors and tensor iterables in args; other values pass through
+        inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                inputs.append(clone_tensor(arg, preserve_requires_grad))
+            elif is_iterable_of_tensors(arg):
+                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
+            else:
+                inputs.append(arg)
+
+        return inputs
+
+    # Returns tensors in args that requires_grad, including tensors in TensorList args
+    def get_recording_tensors(args):
+        recording_tensors: list[torch.Tensor] = []
+
+        for arg in args:
+            if isinstance(arg, torch.Tensor) and arg.requires_grad:
+                recording_tensors.append(arg)
+            elif is_iterable_of_tensors(arg):
+                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))
+
+        return recording_tensors
+
+    # test no gradients case
+    nograd_inputs = clone_inputs(preserve_requires_grad=False)
+    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
+    with enable_profiling_mode_for_profiling_tests():
+        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
+    self.assertEqual(outputs, outputs_test)
+
+    if check_types:
+        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
+
+    if no_grad:
+        # skip grad tests
+        return
+
+    with enable_profiling_mode_for_profiling_tests():
+        # test single grad case
+        recording_inputs = clone_inputs(preserve_requires_grad=True)
+        recording_tensors = get_recording_tensors(recording_inputs)
+        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
+        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
+                                    allow_unused=allow_unused)
+        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
+        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
+                                         allow_unused=allow_unused)
+        self.assertEqual(outputs, outputs_test)
+        self.assertEqual(grads, grads_test)
+        # test the grad grad case
+        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
+            return
+
+        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
+        l1 = allSum(outputs)
+        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
+                                    allow_unused=allow_unused)
+
+        l2 = (allSum(grads) * l1)  # scalar built from the first-order grads; differentiating it gives second-order grads
+        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
+        recording_inputs = clone_inputs(preserve_requires_grad=True)  # re-clone so func's grads use fresh detached copies
+        recording_tensors = get_recording_tensors(recording_inputs)
+        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
+        l1_test = allSum(outputs_test)
+        grads_test = torch.autograd.grad(
+            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
+
+        l2_test = (allSum(grads_test) * l1_test)
+        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
+
+        self.assertEqual(outputs, outputs_test)
+        self.assertEqual(grads, grads_test)
+        for g2, g2_test in zip(grads2, grads2_test, strict=True):
+            if g2 is None and g2_test is None:
+                continue
+            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)  # second-order grads compared with explicit tolerances
+
+class JitCommonTestCase(TestCase):
+    def createFunctionFromGraph(self, trace):  # accepts either a torch._C.Graph or an object exposing .graph()
+        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
+        return torch._C._create_function_from_graph("forward", graph)
+
+    def assertExportImport(self, trace, inputs):
+        m = self.createFunctionFromGraph(trace)
+        self.assertExportImportModule(m, inputs)
+
+    def assertExportImportModule(self, m, inputs):  # m and its save/load copy must give equal results on inputs
+        m_import = self.getExportImportCopy(m)
+        a = self.runAndSaveRNG(m, inputs)
+        b = self.runAndSaveRNG(m_import, inputs)
+        self.assertEqual(a, b, "Results of original model and "
+                               "exported/imported version of model differed")
+
+    def runAndSaveRNG(self, func, inputs, kwargs=None):
+        kwargs = kwargs if kwargs else {}
+        with freeze_rng_state():  # presumably restores RNG state on exit so repeated runs see the same randomness — confirm
+            results = func(*inputs, **kwargs)
+        return results
+
+    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
+        buffer = io.BytesIO()  # first round-trip goes through an in-memory buffer
+        torch.jit.save(m, buffer)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer, map_location=map_location)
+
+        if not also_test_file:
+            return imported
+
+        with TemporaryFileName() as fname:  # second round-trip: save the imported copy to a temp file and reload it
+            torch.jit.save(imported, fname)
+            return torch.jit.load(fname, map_location=map_location)
+
+    def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
+                             fusion_nodes_not_found, non_fusible_nodes_being_fused,
+                             fusion_nodes_found, nodes_in_diff_graph):
+        err_msg = "\nFailure in testing nodes' autodifferentiation. "
+        if should_autodiff_node:
+            err_msg += "One or more nodes were expected to be autodiffed, " \
+                "but were not found in specified fusible/nonfusible " \
+                "DifferentiableGraph groups. \nSpecifically:"
+            # The node is intended to appear in a differentiable graph but doesn't
+            diff_nodes_missing = []
+            # The node is intended to appear in a differentiable graph
+            # outside of a fusion group but instead is in a fusion group
+            diff_nodes_in_fusion = []
+            # The node is intended to appear in a fusion group but doesn't
+            fusion_nodes_missing = []
+            # The node is intended to appear in a fusion group but instead
+            # is just in an outer differentiable graph
+            fusion_nodes_in_diff = []
+            for node in nodes_not_in_diff_graph:
+                if node in non_fusible_nodes_being_fused:
+                    diff_nodes_in_fusion.append(node)
+                else:
+                    diff_nodes_missing.append(node)
+            for node in fusion_nodes_not_found:
+                if node in nodes_in_diff_graph:
+                    fusion_nodes_in_diff.append(node)
+                else:
+                    fusion_nodes_missing.append(node)
+            if len(diff_nodes_missing) > 0:
+                err_msg += f"\n  {diff_nodes_missing} were not in one of the " \
+                    "DifferentiableGraphs when they were expected to be. " \
+                    "Did you intend for these nodes to be autodiffed? " \
+                    "If not, remove them from the list of nonfusible nodes."
+            if len(diff_nodes_in_fusion) > 0:
+                err_msg += f"\n  {diff_nodes_in_fusion} were found in one of the FusionGroups " \
+                    "when they were expected to be just in a DifferentiableGraph. If it was " \
+                    "intended for these nodes to be in FusionGroups, reclassify these nodes as " \
+                    "fusible nodes. If these nodes were not intended to be fused, your " \
+                    "autodifferentiation logic might be wrong."
+            if len(fusion_nodes_missing) > 0:
+                err_msg += f"\n  {fusion_nodes_missing} were not in one of the FusionGroups " \
+                    "of the DifferentiableGraphs when they were expected to be. " \
+                    "They were also not found in an outer DifferentiableGraph. Did you " \
+                    "intend for these nodes to be autodifferentiated? If not, you should " \
+                    "remove these nodes from the test's fusible nodes. Otherwise your " \
+                    "autodifferentiation logic might be wrong."
+            if len(fusion_nodes_in_diff) > 0:
+                err_msg += f"\n  {fusion_nodes_in_diff} were not in one of the FusionGroups " \
+                    "of the DifferentiableGraphs when they were expected to be, " \
+                    "instead they were found just in an outer DifferentiableGraph. " \
+                    "Did you intend for these nodes to be fused? If not, you should " \
+                    "move these nodes into the test's nonfusible nodes. Otherwise your " \
+                    "autodifferentiation logic might be wrong."
+        else:
+            err_msg += "One or more nodes were not expected to be autodiffed " \
+                "but were found in a DifferentiableGraph or in a FusionGroup " \
+                "of a DifferentiableGraph. Did you intend for these nodes to be " \
+                "autodiffed? If so, change this test to expect autodifferentiation. " \
+                "\nSpecifically:"
+            if len(fusion_nodes_found) > 0:
+                err_msg += f"\n  {fusion_nodes_found} were not expected to be in " \
+                    "one of the DifferentiableGraphs, but appeared in a FusionGroup " \
+                    "of a DifferentiableGraph. "
+            if len(nodes_in_diff_graph) > 0:
+                err_msg += f"\n  {nodes_in_diff_graph} were not expected to " \
+                    "be in one of the DifferentiableGraphs but were."
+        return err_msg
+
+    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
+        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
+        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
+
+        # Note: currently no tests have fusible_nodes
+        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
+        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
+
+        # For any non-fusible node, it must show up in one of the DifferentiableGraphs.
+        nodes_in_diff_graph = []
+        nodes_not_in_diff_graph = []
+        non_fusible_nodes_being_fused = []
+        for node in nonfusible_nodes:
+            if any(g.findNode(node) is not None for g in diff_subgraphs):
+                nodes_in_diff_graph.append(node)
+            else:
+                nodes_not_in_diff_graph.append(node)
+            if any(g.findNode(node) is not None for g in fusion_subgraphs):
+                non_fusible_nodes_being_fused.append(node)
+        found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)
+
+        # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
+        fusion_nodes_found = []
+        fusion_nodes_not_found = []
+        for node in fusible_nodes:
+            if any(g.findNode(node) is not None for g in fusion_subgraphs):
+                fusion_nodes_found.append(node)
+            else:
+                fusion_nodes_not_found.append(node)
+        found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)
+
+        if should_autodiff_node is not None:  # passing None skips the assertion entirely
+            err_msg = self.autoDiffErrorMessage(should_autodiff_node,
+                                                nodes_not_in_diff_graph,
+                                                fusion_nodes_not_found,
+                                                non_fusible_nodes_being_fused,
+                                                fusion_nodes_found,
+                                                nodes_in_diff_graph)
+            self.assertEqual(should_autodiff_node,
+                             found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
+
+    def checkShapeAnalysis(self, out_sizes: Union[list[int], list[list[int]]],
+                           traced_graph, assert_propagation, constant_prop=True):
+        # repropagate input shapes provided by tracing,
+        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
+        for enable_test_mode in [True, False]:
+            # here we are testing allowing/disallowing substituting in complete shapes as constants,
+            # disallowing constants helps stress test partial eval and substitution pipeline
+            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
+            torch._C._jit_erase_non_input_shape_information(traced_graph)
+            if constant_prop:
+                torch._C._jit_pass_constant_propagation(traced_graph)
+            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
+            # Add sizes to default tensor type to avoid checking something out of scope
+            # and difficulties with tracer leaving in other parts of tensor type
+            output = next(traced_graph.outputs()).type()  # only the first graph output is inspected
+
+            def test_type(type, actual_size):
+                sizes = type.symbolic_sizes()
+                out_type = TensorType.get().with_sizes(sizes)
+                actual_type = TensorType.get().with_sizes(actual_size)
+
+                # always check actual shape is a subtype of the output
+                self.assertTrue(actual_type.isSubtypeOf(out_type))
+
+                # and then if assertion flag is provided, check shape analysis
+                # is successful
+                if assert_propagation:
+                    self.assertEqual(out_type.sizes(), actual_size)
+
+            if output.isSubtypeOf(torch._C.TensorType.get()):
+                test_type(output, out_sizes)
+            else:  # non-tensor output: treated as a tuple, with out_sizes indexed per element
+                tuple_elements = output.elements()
+                for i in range(len(tuple_elements)):
+                    test_type(tuple_elements[i], out_sizes[i])
+
+        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)  # NOTE(review): not in a finally; skipped if an assertion above fails
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..70ab98137bd712de4c5b0e998e26bd585ff4433c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py
@@ -0,0 +1,113 @@
+# mypy: ignore-errors
+
+import contextlib
+import functools
+import inspect
+
+import torch
+
+
+def bf32_is_not_fp32():  # True only if mkldnn is available and reports bf16 support
+    if not torch.backends.mkldnn.is_available():
+        return False
+    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+        return False
+    return True
+
+
+def tf32_is_not_fp32():  # True only if mkldnn is available and the AMX-fp16 query below succeeds
+    if not torch.backends.mkldnn.is_available():
+        return False
+    if not torch._C._cpu._is_amx_fp16_supported():  # NOTE(review): tf32 is gated on an AMX fp16 check — confirm intended
+        return False
+    return True
+
+
+@contextlib.contextmanager
+def reduced_f32_off():  # sets mkldnn matmul/conv fp32_precision to "ieee" for the duration of the with-block
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
+    try:
+        torch.backends.mkldnn.matmul.fp32_precision = "ieee"
+        torch.backends.mkldnn.conv.fp32_precision = "ieee"
+        yield
+    finally:  # restore the saved settings even if the body raises
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
+
+
+@contextlib.contextmanager
+def bf32_on(self, bf32_precision=1e-2):  # `self` is any object with a `precision` attribute (a test case at the call site below)
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
+    old_precision = self.precision
+    try:
+        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
+        torch.backends.mkldnn.conv.fp32_precision = "bf16"
+        self.precision = bf32_precision  # override self.precision while bf16 mode is active
+        yield
+    finally:  # restore backend settings and self.precision even if the body raises
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
+        self.precision = old_precision
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):  # `self` is any object with a `precision` attribute (a test case at the call site below)
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
+    old_precision = self.precision
+    try:
+        torch.backends.mkldnn.matmul.fp32_precision = "tf32"
+        torch.backends.mkldnn.conv.fp32_precision = "tf32"
+        self.precision = tf32_precision  # override self.precision while tf32 mode is active
+        yield
+    finally:  # restore backend settings and self.precision even if the body raises
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
+        self.precision = old_precision
+
+
+# Decorator factory. If the test's `device` is CPU and `dtype` is torch.float
+# (when those arguments are present) and bf32 and/or tf32 is supported, the
+# test runs once with reduced_f32 OFF and once per supported mode with it ON,
+# using the given precision. Otherwise the test runs a single time, unmodified.
+def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5):
+    def with_reduced_f32_disabled(self, function_call):
+        with reduced_f32_off():
+            function_call()
+
+    def with_bf32_enabled(self, function_call):
+        with bf32_on(self, bf32_precision):
+            function_call()
+
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):  # f's return value is discarded; wrapped always returns None
+            kwargs.update(zip(arg_names, args, strict=False))  # map positional args to parameter names; f is then called by keyword only
+            cond = True
+            if "device" in kwargs:
+                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
+            if "dtype" in kwargs:
+                cond = cond and (kwargs["dtype"] == torch.float)
+            bf32_cond = cond and bf32_is_not_fp32()
+            tf32_cond = cond and tf32_is_not_fp32()
+            if bf32_cond or tf32_cond:
+                with_reduced_f32_disabled(kwargs["self"], lambda: f(**kwargs))  # requires f to have a parameter named `self`
+                if bf32_cond:
+                    with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
+                if tf32_cond:
+                    with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..83fca0b973856ad05dcdd417f1f46f85bcd8591f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py
@@ -0,0 +1,4380 @@
+# mypy: ignore-errors
+
+import torch
+import unittest
+from copy import deepcopy
+from enum import Enum
+from functools import wraps, partial
+from itertools import chain, product
+import itertools
+import math
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pack_padded_sequence
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import TEST_CUDNN
+from torch.testing._internal.common_dtype import (
+    floating_types, floating_and_complex_types_and, get_all_fp_dtypes)
+from torch.testing._internal.common_device_type import (
+    _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol,
+    precisionOverride, skipMeta, skipMPS)
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_nn import (
+    cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference,
+    hingeembeddingloss_reference, huberloss_reference, kldivloss_reference,
+    marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference,
+    nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction)
+from torch.testing._internal.common_utils import (
+    freeze_rng_state, skipIfMPS, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
+    skipIfTorchDynamo)
+from types import ModuleType
+import operator
+
+# List of all namespaces containing modules to test.
+# Each namespace is listed exactly once: a repeated entry would put every class
+# it exports into MODULE_CLASSES twice.
+MODULE_NAMESPACES: list[ModuleType] = [
+    torch.nn.modules,
+    torch.ao.nn.qat.modules,
+    torch.ao.nn.quantizable.modules,
+    torch.ao.nn.quantized.modules,
+]
+
+# Modules that shouldn't be tested for one reason or another.
+MODULES_TO_SKIP: set[type] = {
+    torch.nn.Module,  # abstract base class
+    torch.nn.Container,  # deprecated
+    torch.nn.NLLLoss2d,  # deprecated
+    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
+}
+
+# List of all module classes to test: everything exported through ``__all__`` by
+# the namespaces above, minus the classes in MODULES_TO_SKIP.
+MODULE_CLASSES: list[type] = [*chain.from_iterable([
+    [getattr(namespace, module_name) for module_name in namespace.__all__]  # type: ignore[attr-defined]
+    for namespace in MODULE_NAMESPACES])]
+MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
+
+# Dict of module class -> common name. Useful for making test names more intuitive.
+# Example: torch.nn.modules.linear.Linear -> "nn.Linear"
+MODULE_CLASS_NAMES: dict[type, str] = {}
+for namespace in MODULE_NAMESPACES:
+    for module_name in namespace.__all__:  # type: ignore[attr-defined]
+        module_cls = getattr(namespace, module_name)
+        # e.g. "torch.ao.nn.quantized.modules" -> "ao.nn.quantized"
+        namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '')
+
+        # Deal with any aliases by preferring earlier names.
+        if module_cls not in MODULE_CLASS_NAMES:
+            MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'
+
+
+# Specifies the modes (i.e. train, eval) to test over.
+TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
+
+
+class modules(_TestParametrizer):
+    """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test.
+
+    One parametrized test is produced per (module, training flag, dtype)
+    combination; see ``_parametrize_test``.
+    """
+
+    def __init__(self, module_info_iterable, allowed_dtypes=None,
+                 train_eval_mode=TrainEvalMode.train_and_eval, skip_if_dynamo=True):
+        # Materialize the iterable so it can be walked again for every decorated test.
+        self.module_info_list = list(module_info_iterable)
+        # None means "no restriction"; otherwise only these dtypes are kept.
+        self.allowed_dtypes = None if allowed_dtypes is None else set(allowed_dtypes)
+        self.train_eval_mode = train_eval_mode
+        self.skip_if_dynamo = skip_if_dynamo
+
+    def _get_training_flags(self, module_info):
+        """Return the ``training`` values (True before False) to test for ``module_info``."""
+        mode = self.train_eval_mode
+        flags = []
+        if mode in (TrainEvalMode.train_only, TrainEvalMode.train_and_eval):
+            flags.append(True)
+        if mode in (TrainEvalMode.eval_only, TrainEvalMode.train_and_eval):
+            flags.append(False)
+
+        # If train and eval modes don't differ for the module, don't bother using more than one.
+        return flags if module_info.train_and_eval_differ else flags[:1]
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """Yield ``(test_wrapper, test_name, param_kwargs, decorator_fn)`` per combination.
+
+        Raises:
+            RuntimeError: if ``device_cls`` is None, i.e. the decorator is used
+                outside a device-specific context.
+        """
+        if device_cls is None:
+            raise RuntimeError('The @modules decorator is only intended to be used in a device-specific '
+                               'context; use it with instantiate_device_type_tests() instead of '
+                               'instantiate_parametrized_tests()')
+
+        for module_info in self.module_info_list:
+            supported = set(module_info.supported_dtypes(device_cls.device_type))
+            dtypes = supported if self.allowed_dtypes is None else supported & self.allowed_dtypes
+
+            training_flags = self._get_training_flags(module_info)
+            # The mode suffix is only needed to disambiguate when both modes are generated.
+            mode_in_name = len(training_flags) > 1
+            for training in training_flags:
+                for dtype in dtypes:
+                    # Construct the test name; device / dtype parts are handled outside.
+                    # See [Note: device and dtype suffix placement]
+                    test_name = module_info.formatted_name
+                    if mode_in_name:
+                        test_name += f"_{'train_mode' if training else 'eval_mode'}"
+
+                    # Construct parameter kwargs to pass to the test.
+                    param_kwargs = {'module_info': module_info}
+                    _update_param_kwargs(param_kwargs, 'dtype', dtype)
+                    _update_param_kwargs(param_kwargs, 'training', training)
+
+                    try:
+
+                        @wraps(test)
+                        def test_wrapper(*args, **kwargs):
+                            return test(*args, **kwargs)
+
+                        if self.skip_if_dynamo and not torch.testing._internal.common_utils.TEST_WITH_TORCHINDUCTOR:
+                            test_wrapper = skipIfTorchDynamo("Policy: we don't run ModuleInfo tests w/ Dynamo")(test_wrapper)
+
+                        decorator_fn = partial(module_info.get_decorators, generic_cls.__name__,
+                                               test.__name__, device_cls.device_type, dtype)
+
+                        yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+                    except Exception as ex:
+                        # Provides an error message for debugging before rethrowing the exception
+                        print(f"Failed to instantiate {test_name} for module {module_info.name}!")
+                        raise ex
+
+
+def get_module_common_name(module_cls):
+    """Return the registered common name for ``module_cls`` (e.g. "nn.Linear").
+
+    Classes without an entry in ``MODULE_CLASS_NAMES`` fall back to their ``__name__``.
+    """
+    try:
+        return MODULE_CLASS_NAMES[module_cls]
+    except KeyError:
+        return module_cls.__name__
+
+
+class FunctionInput:
+    """ Contains args and kwargs to pass as input to a function.
+
+    Positional arguments are kept as a tuple in ``args``, keyword arguments as a
+    dict in ``kwargs``.
+    """
+    __slots__ = ['args', 'kwargs']
+
+    def __init__(self, *args, **kwargs):
+        self.args, self.kwargs = args, kwargs
+
+
+class ModuleInput:
+    """ Contains args / kwargs for module instantiation + forward pass.
+
+    Attributes:
+        constructor_input: inputs to pass during construction.
+        forward_input: inputs to pass to forward().
+        desc: description for this set of inputs.
+        reference_fn: None, or a wrapper around the supplied reference that is
+            called as ``reference_fn(module, *args, **kwargs)``.
+    """
+    __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn']
+
+    def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None):
+        self.constructor_input = constructor_input
+        self.forward_input = forward_input
+        self.desc = desc
+
+        if reference_fn is None:
+            self.reference_fn = None
+            return
+
+        # The caller-supplied reference has the signature
+        # reference_fn(module, parameters, *args, **kwargs); the wrapper fills in
+        # the parameters itself.
+        @wraps(reference_fn)
+        def copy_reference_fn(m, *args, **kwargs):
+            # Copy inputs to avoid undesired side effects from calling the reference.
+            safe_args = deepcopy(args)
+            safe_kwargs = deepcopy(kwargs)
+
+            # Note that module parameters are passed in for convenience.
+            return reference_fn(m, list(m.parameters()), *safe_args, **safe_kwargs)
+
+        self.reference_fn = copy_reference_fn
+
+class ModuleErrorEnum(Enum):
+    """ Enumerates the points during module testing at which an error is raised. """
+    CONSTRUCTION_ERROR = 0  # presumably: the error is raised while constructing the module
+    FORWARD_ERROR = 1  # presumably: the error is raised while running the forward pass
+
+class ErrorModuleInput:
+    """
+    Pairs a ModuleInput that is expected to make the operation fail with a
+    description of the resulting error.
+
+    Attributes:
+        module_error_input: the input that triggers the error.
+        error_on: a ModuleErrorEnum member (defaults to CONSTRUCTION_ERROR).
+        error_type: the expected exception class (defaults to RuntimeError).
+        error_regex: pattern for the expected error message (required, keyword-only).
+    """
+
+    __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]
+
+    def __init__(self,
+                 module_error_input,
+                 *,
+                 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+                 error_type=RuntimeError,
+                 error_regex):
+        self.module_error_input, self.error_on = module_error_input, error_on
+        self.error_type, self.error_regex = error_type, error_regex
+
+
+class ModuleInfo:
+    """ Module information to be used in testing.
+
+    Bundles a module class with the function that generates its test inputs and
+    with the per-module settings (dtypes, decorators, gradcheck options) stored
+    below as attributes.
+    """
+
+    def __init__(self,
+                 module_cls,  # Class object for the module under test
+                 *,
+                 module_inputs_func,  # Function to generate module inputs
+                 skips=(),  # Indicates which tests to skip
+                 decorators=None,  # Additional decorators to apply to generated tests
+                 dtypes=floating_types(),  # dtypes this function is expected to work with
+                 dtypesIfMPS=(torch.float16, torch.float32,),  # dtypes this function is expected to work with on MPS
+                 dtypesIfHpu=(torch.bfloat16, torch.float32,),
+                 supports_gradgrad=True,  # whether the op supports second order gradients
+                 gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
+                 module_memformat_affects_out=False,  # whether converting module to channels last will generate
+                                                      # channels last output
+                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
+                 module_error_inputs_func=None,  # Function to generate module inputs that error
+                 gradcheck_fast_mode=None,  # Whether to use the fast implementation for gradcheck/gradgradcheck.
+                                            # When set to None, defers to the default value provided by the wrapper
+                                            # function around gradcheck (testing._internal.common_utils.gradcheck)
+                 ):
+        self.module_cls = module_cls
+        self.module_inputs_func = module_inputs_func
+        # Regular decorators come first, followed by the skips, in a single tuple.
+        self.decorators = tuple(decorators or ()) + tuple(skips or ())
+        self.dtypes = dtypes
+        self.dtypesIfMPS = dtypesIfMPS
+        self.dtypesIfHpu = dtypesIfHpu
+        self.supports_gradgrad = supports_gradgrad
+        self.gradcheck_nondet_tol = gradcheck_nondet_tol
+        self.module_memformat_affects_out = module_memformat_affects_out
+        self.train_and_eval_differ = train_and_eval_differ
+        self.module_error_inputs_func = module_error_inputs_func
+        self.gradcheck_fast_mode = gradcheck_fast_mode
+        self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        """Return the decorators that apply to the given test configuration.
+
+        Plain decorators are always included; a ``DecorateInfo`` contributes its
+        ``decorators`` only when its ``is_active`` check accepts the configuration.
+        """
+        active = []
+        for entry in self.decorators:
+            if not isinstance(entry, DecorateInfo):
+                active.append(entry)
+            elif entry.is_active(test_class, test_name, device, dtype, param_kwargs):
+                active.extend(entry.decorators)
+        return active
+
+    def supported_dtypes(self, device_type):
+        """Return the dtypes to test for ``device_type`` ('mps' and 'hpu' have their own lists)."""
+        if device_type == 'mps':
+            return self.dtypesIfMPS
+        if device_type == 'hpu':
+            return self.dtypesIfHpu
+        return self.dtypes
+
+    @property
+    def name(self):
+        """Common name of the module class, as given by ``get_module_common_name``."""
+        return get_module_common_name(self.module_cls)
+
+    @property
+    def formatted_name(self):
+        """``name`` with every '.' replaced by '_'."""
+        return self.name.replace('.', '_')
+
+# Start of module inputs functions.
+
+def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.Linear: with bias, without bias, and unbatched."""
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # In the references below, p is the module's parameter list: p[0] is the
+    # first parameter and p[1] the second (present only when bias is enabled).
+    with_bias = ModuleInput(
+        constructor_input=FunctionInput(10, 8),
+        forward_input=FunctionInput(input=make_input((4, 10))),
+        reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8))
+
+    without_bias = ModuleInput(
+        constructor_input=FunctionInput(10, 8, bias=False),
+        forward_input=FunctionInput(make_input((4, 10))),
+        desc='no_bias',
+        reference_fn=lambda m, p, i: torch.mm(i, p[0].t()))
+
+    unbatched = ModuleInput(
+        constructor_input=FunctionInput(3, 5),
+        forward_input=FunctionInput(make_input(3)),
+        desc='no_batch_dim',
+        reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1])
+
+    return [with_bias, without_bias, unbatched]
+
+
+def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.Bilinear: with bias, without bias, and unbatched."""
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def bilinear_reference_fn(m, p, x1, x2, bias=True):
+        # p[0] is contracted against both inputs; p[1] is added when bias is set.
+        result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2)
+        if not bias:
+            return result
+        if x1.shape[0] == 1:
+            # A leading dimension of 1 yields a flattened result.
+            return result.view(-1) + p[1]
+        return result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0])
+
+    with_bias = ModuleInput(
+        constructor_input=FunctionInput(2, 3, 4),
+        forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
+        reference_fn=bilinear_reference_fn)
+
+    without_bias = ModuleInput(
+        constructor_input=FunctionInput(2, 3, 4, bias=False),
+        forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
+        desc='no_bias',
+        reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False))
+
+    unbatched = ModuleInput(
+        constructor_input=FunctionInput(2, 3, 4),
+        forward_input=FunctionInput(make_input(2), make_input(3)),
+        desc='no_batch_dim',
+        reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1)))
+
+    return [with_bias, without_bias, unbatched]
+
+
+def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.KLDivLoss.
+
+    For every constructor configuration in ``cases`` two samples are produced: one
+    with ``(10, 10)`` tensors and one with 0-dim tensors (``desc`` prefixed with
+    ``scalar_``). Inputs are always passed through ``log()``; targets are passed
+    through ``log()`` only when the configuration sets ``log_target=True``.
+    """
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_batchmean', {'reduction': 'batchmean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('log_target', {'log_target': True})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        # constructor_kwargs is bound as a default so each reference keeps its own case.
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return kldivloss_reference(i, t, **constructor_kwargs)
+
+        # The target's space must follow the module being constructed, so read the
+        # flag from this case's constructor kwargs (the generator's own **kwargs
+        # does not describe the module under test).
+        log_target = constructor_kwargs.get('log_target', False)
+
+        # NOTE(review): make_input is called without a value range, so log() may
+        # produce NaN for non-positive samples — confirm whether that is intended.
+        inp = make_input((10, 10)).log()
+        target = make_input((10, 10)).log() if log_target else make_input((10, 10))
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(inp, target),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+        scalar_input = make_input(()).log()
+        scalar_target = make_input(()).log() if log_target else make_input(())
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        # Pass a distinct target rather than reusing the input tensor.
+                        forward_input=FunctionInput(scalar_input, scalar_target),
+                        desc='scalar_' + desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.NLLLoss.
+
+    Every constructor configuration in ``cases`` is exercised with four input
+    shapes: ``(15, 4)``, ``(2, 4, 5, 5)``, ``(2, 4, 5, 5, 2, 2)`` and ``(2, 4, 5)``.
+    """
+    def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad):
+        # log_softmax is applied before gradient tracking is switched on.
+        raw = make_tensor(shape, device=device, dtype=dtype, requires_grad=False)
+        return raw.log_softmax(dim=1).requires_grad_(requires_grad)
+
+    def make_class_indices(*shape):
+        # Integer tensor filled from uniform_() scaled by 4 and floored.
+        return torch.empty(*shape, device=device).uniform_().mul(4).floor().long()
+
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('ignore_index', {'ignore_index': 2}),
+        ('weights', {'weight': make_weight(4).abs()}),
+        ('weights_ignore_index', {'weight': make_weight(4).abs(), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_weight(4).abs(), 'ignore_index': -1})
+    ]
+
+    # TODO: Uncomment when negative weights is supported.
+    # negative_weight = make_weight(10)
+    # negative_weight[0] = -1
+    # cases.append(('weights_negative', {'weight': negative_weight}))
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return nllloss_reference(i, t, **constructor_kwargs)
+
+        def nd_reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return nlllossNd_reference(i, t, **constructor_kwargs)
+
+        variants = (
+            (desc, reference_fn, (15, 4)),
+            (f"nd_{desc}", nd_reference_fn, (2, 4, 5, 5)),
+            (f"higher_dim_{desc}", nd_reference_fn, (2, 4, 5, 5, 2, 2)),
+            (f"3d_{desc}", nd_reference_fn, (2, 4, 5)),
+        )
+        for variant_desc, variant_reference, input_shape in variants:
+            # The target has the input's shape with dim 1 removed.
+            target_shape = input_shape[:1] + input_shape[2:]
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                            forward_input=FunctionInput(make_input(input_shape),
+                                                        make_class_indices(*target_shape)),
+                            desc=variant_desc,
+                            reference_fn=variant_reference)
+            )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.GaussianNLLLoss.
+
+    The ``homoscedastic`` key in a case is not a constructor argument: it only
+    selects the shape of the variance tensor, ``(1, 3)`` instead of ``(4, 1)``.
+    """
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('homoscedastic', {'homoscedastic': True}),
+    ]
+
+    module_inputs = []
+    for desc, case_kwargs in cases:
+        # Separate the shape selector from the real constructor kwargs.
+        var_shape = (1, 3) if case_kwargs.get('homoscedastic', False) else (4, 1)
+        ctor_kwargs = {k: v for k, v in case_kwargs.items() if k != 'homoscedastic'}
+
+        var_input = make_input(*var_shape).abs()
+        forward_input = FunctionInput(make_input(4, 3), make_target(4, 3), var_input)
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**ctor_kwargs),
+                        forward_input=forward_input,
+                        desc=desc,
+                        reference_fn=no_batch_dim_reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.PoissonNLLLoss, one per entry in ``cases``."""
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('full', {'full': True}),
+        ('no_log_input', {'log_input': False}),
+        ('full_no_log_input', {'full': True, 'log_input': False}),
+    ]
+
+    def poissonnllloss_reference_fn(i, t, log_input=True, full=False, reduction='mean', eps=1e-8):
+        result = i.exp() - t.mul(i) if log_input else i - t.mul((i + eps).log())
+
+        if full:
+            # Extra term, zeroed wherever the target is <= 1.
+            full_term = t.mul(t.log()) - t + 0.5 * (2. * math.pi * t).log()
+            result += full_term.masked_fill(t <= 1, 0)
+
+        if reduction == 'none':
+            return result
+        total = result.sum()
+        return total / i.numel() if reduction == 'mean' else total
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return poissonnllloss_reference_fn(i, t, **constructor_kwargs)
+
+        # Without log_input the input is made strictly positive.
+        if constructor_kwargs.get('log_input', True):
+            inp = make_input((2, 3, 4, 5))
+        else:
+            inp = make_input((2, 3, 4, 5)).abs().add(0.001)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(inp,
+                                                    make_target((2, 3, 4, 5)).floor_().abs_()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.MSELoss.
+
+    Each reduction configuration is exercised with ``(2, 3, 4, 5)`` tensors and
+    with 0-dim tensors (``desc`` suffixed with ``_scalar``).
+    """
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    def mse_loss_reference_fn(m, p, i, t, reduction='mean'):
+        squared_error = (i - t).pow(2)
+        if reduction == 'none':
+            return squared_error
+        total = squared_error.sum()
+        return total / i.numel() if reduction == 'mean' else total
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        for shape, sample_desc in (((2, 3, 4, 5), desc), ((), f'{desc}_scalar')):
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                            forward_input=FunctionInput(make_input(shape),
+                                                        make_target(shape)),
+                            desc=sample_desc,
+                            reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
+            )
+
+    return module_inputs
+
+
+def no_batch_dim_reference_fn(m, p, *args, **kwargs):
+    """Reference function for modules supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a single-batch input before being
+    passed to the module, and the output is squeezed again so it can be compared
+    with the module's output for the unbatched input.
+
+    Currently it only supports modules which return a single Tensor as output.
+    You can bind the following kwargs; they are removed before the module is called.
+    Kwargs:
+        batch_first[bool] : If True (the default), all the Tensors in `args` will be
+                        unsqueezed at dim `0` and the output squeezed at dim `0`;
+                        otherwise dim `1` is used for both.
+        kwargs_to_batchify[dict] : Dictionary mapping an argument name to the dimension to unsqueeze.
+                               Useful if there are a few arguments whose batch dimension differs
+                               from the one selected by `batch_first`.
+        is_criterion[bool] : Specify if the module is a criterion and handle the reduction for output accordingly.
+    """
+    batch_dim = 0 if kwargs.pop('batch_first', True) else 1
+    kwargs_to_batchify = kwargs.pop('kwargs_to_batchify', None)
+    is_criterion = kwargs.pop('is_criterion', False)
+
+    if kwargs_to_batchify is not None:
+        assert isinstance(kwargs_to_batchify, dict)
+        # Only existing keys are reassigned, so iterating the dict here is safe.
+        for name, value in kwargs.items():
+            if name in kwargs_to_batchify and value is not None:
+                kwargs[name] = value.unsqueeze(kwargs_to_batchify[name])
+
+    batched_args = [arg.unsqueeze(batch_dim) for arg in args]
+    with freeze_rng_state():
+        output = m(*batched_args, **kwargs).squeeze(batch_dim)
+
+    if is_criterion and get_reduction(m) == 'none':
+        return output.squeeze(0)
+    return output
+
+
+def no_batch_dim_reference_mha(m, p, *args, **kwargs):
+    """Reference function for MultiheadAttention supporting no batch dimensions.
+
+    Positional inputs are unsqueezed at the batch dimension (dim 0 when
+    ``batch_first`` is true or absent, dim 1 otherwise) and a non-None
+    ``key_padding_mask`` is unsqueezed at dim 0. The first output is squeezed
+    at the batch dimension and the second at dim 0.
+    """
+    batch_dim = 0 if kwargs.pop('batch_first', True) else 1
+
+    padding_mask = kwargs.get('key_padding_mask')
+    if padding_mask is not None:
+        kwargs['key_padding_mask'] = padding_mask.unsqueeze(0)
+
+    batched_args = [arg.unsqueeze(batch_dim) for arg in args]
+    with freeze_rng_state():
+        output = m(*batched_args, **kwargs)
+        return (output[0].squeeze(batch_dim), output[1].squeeze(0))
+
+
+def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
+    """Reference function for RNN and GRU supporting no batch dimensions.
+
+    ``args`` is either ``(input,)`` or ``(input, hidden)``. The input is
+    unsqueezed at the batch dimension chosen by the required ``batch_first``
+    kwarg and the hidden state, when given, at dim 1. The two outputs are
+    squeezed at the batch dimension and at dim 1 respectively.
+    """
+    if len(args) == 1:
+        inp, h = args[0], None
+    elif len(args) == 2:
+        inp, h = args[0], args[1].unsqueeze(1)
+
+    # batch_first is consumed here and is not forwarded to the module call.
+    batch_dim = 0 if kwargs.pop('batch_first') else 1
+    batched_args = (inp.unsqueeze(batch_dim), h)
+    with freeze_rng_state():
+        output = m(*batched_args, **kwargs)
+        return (output[0].squeeze(batch_dim), output[1].squeeze(1))
+
+
+def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
+    """Reference function for LSTM supporting no batch dimensions.
+
+    ``args`` is either ``(input,)`` or ``(input, (h, c))``. The input is
+    unsqueezed at the batch dimension chosen by the required ``batch_first``
+    kwarg and both state tensors, when given, at dim 1. The output is squeezed
+    at the batch dimension and both returned state tensors at dim 1.
+    """
+    if len(args) == 1:
+        inp, h = args[0], None
+    elif len(args) == 2:
+        inp, state = args
+        h = tuple(part.unsqueeze(1) for part in (state[0], state[1]))
+
+    # batch_first is consumed here and is not forwarded to the module call.
+    batch_dim = 0 if kwargs.pop('batch_first') else 1
+    batched_args = (inp.unsqueeze(batch_dim), h)
+    with freeze_rng_state():
+        output = m(*batched_args, **kwargs)
+        new_state = output[1]
+        return (output[0].squeeze(batch_dim), (new_state[0].squeeze(1), new_state[1].squeeze(1)))
+
+
+def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
+    """Reference function for LSTMCell supporting no batch dimensions.
+
+    ``args`` must be ``(input, (h, c))``. All three tensors are unsqueezed at
+    dim 0 before the module is called, and the first two outputs are squeezed
+    at dim 0 for comparison with the no-batch result.
+    """
+    inp, (h, c) = args
+    batched_state = (h.unsqueeze(0), c.unsqueeze(0))
+    with freeze_rng_state():
+        output = m(inp.unsqueeze(0), batched_state, **kwargs)
+        return (output[0].squeeze(0), output[1].squeeze(0))
+
+
+def generate_regression_criterion_inputs(make_input):
+    """Return one no-batch-dim ModuleInput per reduction mode ('none', 'mean', 'sum').
+
+    ``make_input`` is called twice per sample to create the two forward inputs.
+    """
+    module_inputs = []
+    for reduction in ('none', 'mean', 'sum'):
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(reduction=reduction),
+                forward_input=FunctionInput(make_input((4, )), make_input(4,)),
+                reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True),
+                desc=f'no_batch_dim_{reduction}'
+            )
+        )
+    return module_inputs
+
+
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.AvgPool1d: one unbatched, three batched."""
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The unbatched sample is built first so tensors are created in list order.
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput(kernel_size=2),
+                    forward_input=FunctionInput(make_input((3, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)
+    ]
+
+    # (constructor arguments, description) for the batched samples.
+    batched_cases = (
+        (FunctionInput(2), ''),
+        (FunctionInput((2,), (2,)), 'stride'),
+        (FunctionInput(2, 2, 1), 'stride_pad'),
+    )
+    for constructor_input, desc in batched_cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=constructor_input,
+                        forward_input=FunctionInput(make_input((2, 3, 6))),
+                        desc=desc))
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_AvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return ModuleInput samples for nn.AvgPool2d: one unbatched, six batched.
+
+    The batched samples cover kernel only, kernel + stride and kernel + stride +
+    padding, each without and with ``divisor_override=1``.
+    """
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The unbatched sample is built first so tensors are created in list order.
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput((2, 2)),
+                    forward_input=FunctionInput(make_input((3, 6, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)
+    ]
+
+    # (constructor arguments, description) for the batched samples.
+    batched_cases = (
+        (FunctionInput((2, 2)), ''),
+        (FunctionInput((2, 2), (2, 2)), 'stride'),
+        (FunctionInput((2, 2), (2, 2), (1, 1)), 'stride_pad'),
+        (FunctionInput((2, 2), divisor_override=1), 'divisor'),
+        (FunctionInput((2, 2), (2, 2), divisor_override=1), 'divisor_stride'),
+        (FunctionInput((2, 2), (2, 2), (1, 1), divisor_override=1), 'divisor_stride_pad'),
+    )
+    for constructor_input, desc in batched_cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=constructor_input,
+                        forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                        desc=desc))
+
+    return module_inputs
+
+
+
+def module_inputs_torch_nn_AvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
+                    forward_input=FunctionInput(make_input((3, 4, 4, 4))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
+        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride_pad'),
+        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride_pad_gpu_fixedkw_output'),
+        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
+                    desc='stride_pad_gpu_general_output'),
+        ModuleInput(constructor_input=FunctionInput(3, 1, 0),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='stride1_pad0_gpu_input'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='stride_pad_gpu_input_nooverlap'),
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor'),
+        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride_pad'),
+        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride_pad_gpu_fixedkw_output'),
+        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
+                    desc='divisor_stride_pad_gpu_general_output'),
+        ModuleInput(constructor_input=FunctionInput(3, 1, 0, divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor_stride1_pad0_gpu_input'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor_stride_pad_gpu_input_nooverlap')]
+
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='one_output')]
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single_1x1output'),
+        ModuleInput(constructor_input=FunctionInput((3, 4)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple_none')]
+
+def module_inputs_torch_nn_AdaptiveAvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 2, 7))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 2, 7))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((None, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
+                    desc='tuple_none'),
+        ModuleInput(constructor_input=FunctionInput((3, 2, 2)),
+                    forward_input=FunctionInput(make_input((1, 1, 3, 2, 6))),
+                    desc='last_dim')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple_none')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6, 7))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='tuple_none'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 12, 9, 3))),
+                    desc='single_nonatomic'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 4, 10))),
+                    desc='tuple_nonatomic')]
+
+
+def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(10,),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='affine'),
+        ModuleInput(constructor_input=FunctionInput(5,),
+                    forward_input=FunctionInput(make_input((4, 5, 3))),
+                    desc='3d_input'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, None),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='affine_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, True, False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((4, 5, 3))),
+                    desc='3d_input_not_affine'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 9))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='2d_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='momentum'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, False),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, True, False),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 2, 2))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='3d_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='momentum'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, False),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, True, False),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))),
+                    desc='zero_batch')]
+
+
+def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    if module_info.module_cls == torch.nn.BatchNorm1d:
+        input_shape = (2, 10)
+    elif module_info.module_cls == torch.nn.BatchNorm2d:
+        input_shape = (2, 10, 5, 5)
+    else:
+        input_shape = (2, 10, 4, 4, 4)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, eps=-1.0),
+                forward_input=FunctionInput(make_input(input_shape)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="eps must be positive"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, eps=0.0),
+                forward_input=FunctionInput(make_input(input_shape)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="eps must be positive"
+        ),
+    ]
+
+
+def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
+    N = kwargs['N']
+    lazy = kwargs.get('lazy', False)
+    transposed = kwargs.get('transposed', False)
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
+    kernel_size, C_in, C_out = 3, 4, 5
+    input_no_batch_shape = (C_in,) + tuple(i + 3 for i in range(N))
+    input_batch_shape = (2,) + input_no_batch_shape
+    return [
+        ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
+                                       FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
+                    forward_input=FunctionInput(make_input(
+                        input_batch_shape if with_batch else input_no_batch_shape)),
+                    desc=('' if with_batch else 'no_batch_dim'),
+                    reference_fn=(None if with_batch else no_batch_dim_reference_fn))
+        for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
+    ]
+
+
+def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.7})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
+            return cosineembeddingloss_reference(i1, i2, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10)), make_input((15, 10)),
+                                                    make_target((15,)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 5))),
+                    desc='4d_input')]
+
+
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_GLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6)))),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    desc='dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((4,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_GELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput('none'),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput('none'),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='channels_last_mem_format'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+                    desc='channels_last_3d_mem_format')]
+
+
+def module_inputs_torch_nn_ReLU6(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='channels_last_mem_format'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+                    desc='channels_last_3d_mem_format')]
+
+
+def module_inputs_torch_nn_LeakyReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(0.5),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    desc='with_negval'),
+        ModuleInput(constructor_input=FunctionInput(0.0),
+                    forward_input=FunctionInput(make_input((10, 10))),
+                    desc='with_zero_negval'),
+        ModuleInput(constructor_input=FunctionInput(0.5),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='with_negval_scalar')]
+
+
+def module_inputs_torch_nn_PReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='1d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='1d_multiparam'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='2d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='2d_multiparam'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='3d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='3d_multiparam')]
+
+
+def module_inputs_torch_nn_SELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar')]
+
+
+def module_inputs_torch_nn_SiLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x))]
+
+
+def module_inputs_torch_nn_Softmax(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(0, True)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softmax2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, False))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_LogSoftmax(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
+                    desc='multiparam'),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
+                    desc='multiparam_scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softmin(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20)))),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 10))),
+                    desc='multidim'),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((3, 4, 10))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.log1p(torch.exp(i))),
+        ModuleInput(constructor_input=FunctionInput(2),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: 1. / 2. * torch.log1p(torch.exp(2 * i)),
+                    desc='beta'),
+        ModuleInput(constructor_input=FunctionInput(2, -100),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=(
+                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
+                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
+                    desc='beta_threshold'),
+        ModuleInput(constructor_input=FunctionInput(2, -100),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=(
+                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
+                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
+                    desc='beta_threshold_scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    desc='lambda'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='lambda_scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softsign(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Tanh(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+
+def module_inputs_torch_nn_Tanhshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Threshold(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='threshold_value'),
+        ModuleInput(constructor_input=FunctionInput(2., 10.),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='large_value'),
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='threshold_value_scalar'),
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Mish(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4)),
+                                                make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
+                                                                         for a, b in zip(i, t, strict=True))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(()), make_input(())),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
+                    desc='scalar')] + generate_regression_criterion_inputs(make_input)
+
+
+def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return smoothl1loss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_input((5, 10))),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_input(())),
+                        desc=f'scalar_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+
+def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weights', {'weight': make_weight((10,))}),
+    ]
+
+    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        result = -(t * i.log() + (1 - t) * (1 - i).log())
+
+        if weight is not None:
+            result = result * weight
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
+                                                    make_target((15, 10)).gt(0).to(dtype)),
+                        desc=desc,
+                        reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs))
+        )
+
+    scalar_weight = make_weight(())
+    module_inputs.append(
+        ModuleInput(constructor_input=FunctionInput(weight=scalar_weight),
+                    forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2),
+                                                make_target(()).gt(0).to(dtype)),
+                    desc='scalar_weight',
+                    reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight))
+    )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weights', {'weight': make_weight((10,))}),
+        ('scalar_weights', {'weight': make_weight(())})
+    ]
+
+    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        # TODO: add pos_weight to the definition here and corresponding SampleInputs
+        max_val = (-i).clamp(min=0)
+        result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
+
+        if weight is not None:
+            result = result * weight
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
+                                                    make_target((15, 10)).gt(0).to(dtype)),
+                        desc=desc,
+                        reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    reductions: list[str] = ['mean', 'sum', 'none']
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('weights', {'weight': make_weight((3,))}),
+        ('ignore_index', {'ignore_index': 1}),
+        ('label_smoothing', {'label_smoothing': 0.15}),
+        ('ignore_index_label_smoothing', {'ignore_index': 1, 'label_smoothing': 0.15})
+    ]
+
+    module_inputs = []
+    for reduction, (desc, constructor_kwargs) in product(reductions, cases):
+        def reference_fn(m, p, i, t, reduction=reduction, constructor_kwargs=constructor_kwargs):
+            return cross_entropy_loss_reference(i, t, reduction=reduction, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5, 5)),
+                                                    make_target((2, 5, 5), low=0, high=3)),
+                        desc=f"4d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5)),
+                                                    make_target((2, 5), low=0, high=3)),
+                        desc=f"3d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3)),
+                                                    make_target((2), low=0, high=3)),
+                        desc=f"2d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
+                                                    make_target((2, 5, 5, 2, 2), low=0, high=3)),
+                        desc=f"higher_dim_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+
+        if constructor_kwargs.get('ignore_index', None) is None:
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3, 4, 2)),
+                                                        make_input((5, 3, 4, 2)).softmax(dim=1)),
+                            desc=f"4d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3, 4)),
+                                                        make_input((5, 3, 4)).softmax(dim=1)),
+                            desc=f"3d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3)),
+                                                        make_input((5, 3)).softmax(dim=1)),
+                            desc=f"2d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
+                                                        make_input((2, 3, 5, 5, 2, 2)).softmax(dim=1)),
+                            desc=f"higher_dim_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((3,)),
+                                                        make_target((), low=0, high=3)),
+                            desc=f"no_batch_dim_{desc}_{reduction}",
+                            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
+            )
+
+    return module_inputs
+
+
+
+def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    """Return the sample inputs used to exercise ``torch.nn.CTCLoss``.
+
+    For every (target dtype, constructor configuration) pair four samples are
+    produced: a (3, 30) target and a flat 1-D target of length 30 + 25 + 20,
+    each with the input/target lengths given once as tuples of ints and once
+    as tensors. Every sample is checked against ``ctcloss_reference``.
+    """
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    # dtype is left open here; it is supplied per call from ``target_dtypes``.
+    make_target = partial(make_tensor, device=device, requires_grad=False)
+
+    # (desc, constructor kwargs) pairs.
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('blank', {'blank': 14})
+    ]
+    target_dtypes = [torch.int, torch.long]
+
+    module_inputs = []
+    for target_dtype, (desc, constructor_kwargs) in product(target_dtypes, cases):
+        # Bind constructor_kwargs as a default so each reference keeps its own
+        # configuration rather than the one from the last loop iteration.
+        def reference_fn(m, p, i, t, il, tl, constructor_kwargs=constructor_kwargs):
+            return ctcloss_reference(i, t, il, tl, **constructor_kwargs)
+
+        # Range passed to make_target depends on the configured blank index
+        # (0 when not given): 0..14 for blank == 14, 1..15 otherwise.
+        # NOTE(review): presumably this keeps generated labels from equalling
+        # the blank index, assuming make_tensor's ``high`` is exclusive - confirm.
+        blank = constructor_kwargs.get('blank', 0)
+        low = 0 if blank == 14 else 1
+        high = 14 if blank == 14 else 15
+
+        # 2-D target, lengths as tuples of ints.
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
+                                            (50, 50, 50), (30, 25, 20)),
+                desc=f'{desc}_lengths_intlists',
+                reference_fn=reference_fn)
+        )
+        # 2-D target, lengths as tensors on ``device``.
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
+                                            torch.tensor((50, 50, 50), device=device),
+                                            torch.tensor((30, 25, 20), device=device)),
+                desc=f'{desc}_lengths_tensors',
+                reference_fn=reference_fn)
+        )
+        # Flat 1-D target whose length is the sum of the target lengths, int lengths.
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
+                                            (50, 50, 50), (30, 25, 20)),
+                desc=f'{desc}_1d_target_lengths_intlists',
+                reference_fn=reference_fn)
+        )
+        # Flat 1-D target, lengths as tensors on ``device``.
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
+                                            torch.tensor((50, 50, 50), device=device),
+                                            torch.tensor((30, 25, 20), device=device)),
+                desc=f'{desc}_1d_target_lengths_tensors',
+                reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(3, 6, 1e-3),
+            forward_input=FunctionInput(make_input((4, 6, 5))),
+            desc='1d_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 12, 1e-3),
+            forward_input=FunctionInput(make_input((4, 12))),
+            desc='1d_affine_GN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 6, 1e-3),
+            forward_input=FunctionInput(make_input((150, 6))),
+            desc='1d_affine_large_batch'),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 5, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_affine_IN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 10, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 10))),
+            desc='1d_no_affine_LN'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 6, 1e-3),
+            forward_input=FunctionInput(make_input((4, 6, 2, 3))),
+            desc='2d_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 3, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
+            desc='2d_no_affine_IN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
+            desc='2d_no_affine_LN'),
+    ]
+
+
+def module_error_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    """
+    Error inputs for GroupNorm that test error messages include actual values.
+    """
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(3, 10),  # num_groups=3, num_channels=10
+                forward_input=FunctionInput(),  # Not needed for construction error
+            ),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex=r"num_channels \(10\) must be divisible by num_groups \(3\)"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(5, 13),  # num_groups=5, num_channels=13
+                forward_input=FunctionInput(),  # Not needed for construction error
+            ),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex=r"num_channels \(13\) must be divisible by num_groups \(5\)"
+        ),
+    ]
+
+
+def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2.),
+            forward_input=FunctionInput(make_input((4, 3, 2, 4))),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(2.),
+            forward_input=FunctionInput(make_input(())),
+            desc='scalar',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 2, 5))),
+            desc='4d_input')
+    ]
+
+
+def module_inputs_torch_nn_Hardtanh(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((3, 2, 5))),
+            reference_fn=lambda m, p, i: i.clamp(-1, 1),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            reference_fn=lambda m, p, i: i.clamp(-1, 1),
+            desc='scalar',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.5})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return hingeembeddingloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((10,)),
+                                                    make_target((10,)).gt(0).to(dtype).mul_(2).sub_(1)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_target(()).gt(0).to(dtype).mul_(2).sub_(1)),
+                        desc=f'scalar_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return huberloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_input((5, 10))),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    lazy = kwargs.get('lazy', False)
+    N = kwargs['N']
+    num_features, eps, momentum, affine, track_running_stats = 3, 1e-3, 0.3, False, True
+    input_no_batch_shape_dict = {1: (3, 15), 2: (3, 6, 6), 3: (3, 4, 4, 4)}
+    input_no_batch_shape = input_no_batch_shape_dict[N]
+    input_batch_shape = (4,) + input_no_batch_shape
+
+    return [
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
+            ),
+            forward_input=FunctionInput(make_input(input_batch_shape))),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
+                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
+            ),
+            forward_input=FunctionInput(make_input(input_batch_shape)),
+            desc='tracking_stats'),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
+            ),
+            forward_input=FunctionInput(make_input(input_no_batch_shape)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='tracking_stats_no_batch_dim'),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
+                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
+            ),
+            forward_input=FunctionInput(make_input(input_no_batch_shape)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim')
+    ]
+
+def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine_no_bias'),
+    ]
+
+def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def rms_norm_reference_fn(m, p, i):
+        eps = m.eps
+        if eps is None:
+            eps = torch.finfo(i.dtype).eps
+        ndim = i.ndim
+        normalized_shape = m.normalized_shape
+        weight = m.weight
+        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+        upcasted_i = i.float()
+        result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        if weight is not None:
+            result *= weight
+        return result.type_as(i)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+    ]
+
+
+def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(3,),
+            forward_input=FunctionInput(make_input((1, 5, 7))),
+            desc='1d'),
+        ModuleInput(
+            constructor_input=FunctionInput(2,),
+            forward_input=FunctionInput(make_input((1, 5, 7, 7))),
+            desc='2d_uneven_pad'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 1., 0.5, 2.),
+            forward_input=FunctionInput(make_input((1, 5, 7, 7, 7))),
+            desc='3d_custom_params'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7))),
+            desc='norm'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 3),
+            forward_input=FunctionInput(make_input((1, 3, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 3),
+            forward_input=FunctionInput(make_input((3, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+    ]
+
+
+
+def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='norm'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))),
+            desc='norm'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(4),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='3d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 4),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='stride'),
+        ModuleInput(
+            constructor_input=FunctionInput(4, return_indices=True),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            desc='3d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='4d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='return_indices'),
+    ]
+
+def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, (2, 2, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='stride'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='stride_padding'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, (1, 1, 1), return_indices=True),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_random_samples():
+        return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_()
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
+            desc='ratio'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((1, 3, 7, 6))),
+            desc='size'),
+        ModuleInput(
+            constructor_input=FunctionInput(
+                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
+            ),
+            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
+            desc='ratio_return_indices'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((3, 5, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='ratio_no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((3, 7, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='size_no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_random_samples():
+        return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_()
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
+            desc='ratio'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 7, 7, 7))),
+            desc='size'),
+        ModuleInput(
+            constructor_input=FunctionInput((4, 2, 3), output_size=(10, 3, 2), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 16, 7, 5))),
+            desc='asymsize'),
+        ModuleInput(
+            constructor_input=FunctionInput(
+                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
+            ),
+            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
+            desc='ratio_return_indices'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((4, 5, 5, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='ratio_no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((4, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='size_no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            desc='scalar'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            desc='channels_last_mem_format'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+            desc='channels_last_3d_mem_format'
+        )
+    ]
+
+
+def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            reference_fn=lambda m, p, i: i.sigmoid().log(),
+            desc='scalar'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 4))),
+            reference_fn=lambda m, p, i: i.sigmoid().log(),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+    ]
+
+
+def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.5})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
+            return marginrankingloss_reference(i1, i2, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((50,)), make_input((50,)),
+                                                    make_target((50,)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return multilabelmarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((10,)),
+                                                    make_target((10), low=0, high=10)),
+                        desc=f'1d_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5, 10), low=0, high=10)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('p', {'p': 2}),
+        ('margin', {'margin': 0.5}),
+        ('weights', {'weight': make_weight(10)})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return multimarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5), low=0, high=10)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weight', {'weight': make_weight(10)}),
+    ]
+
+    def multilabelsoftmargin_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        result = t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()
+        if weight is not None:
+            result *= weight
+        result = (-result).sum(i.dim() - 1) / i.size(-1)
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.mean()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5, 10), low=0, high=2)),
+                        desc=desc,
+                        reference_fn=partial(multilabelsoftmargin_loss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return softmarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 5)),
+                                                    make_target((5, 5)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
+    samples = []
+    for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
+            None, device, dtype, requires_grad, training):
+        # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
+        l_args, l_kwargs = (layer_module_input.constructor_input.args,
+                            layer_module_input.constructor_input.kwargs)
+        l_kwargs['device'] = device
+        l_kwargs['dtype'] = dtype
+        encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
+        num_layers = 2
+        # Note: TransformerEncoderLayer takes a "src_mask" while
+        # TransformerEncoder takes a "mask"; rename kwarg appropriately.
+        forward_input = layer_module_input.forward_input
+        if 'src_mask' in forward_input.kwargs:
+            forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
+            del forward_input.kwargs['src_mask']
+        samples.append(ModuleInput(
+            constructor_input=FunctionInput(encoder_layer, num_layers),
+            forward_input=forward_input,
+            desc=layer_module_input.desc
+        ))
+    return samples
+
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ), ]
+
+    # Samples below are for validating the no-batch-dim support.
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+
+    # Samples below where we pass reference_fn are for validating the fast path,
+    # since the fast path requires no_grad mode, we run the fast path in .eval()
+    # and no_grad() in the reference_fn and verify that against the results in train mode.
+    def fast_path_reference_fn(module, parameters, *args, **kwargs):
+        assert module.training
+        module.train(False)
+        with torch.no_grad():
+            output = module(*args, **kwargs)
+        module.train(True)
+        return output
+
+    if training:
+        for norm_first, bias in itertools.product((True, False), (True, False)):
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(
+                        4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
+                    ),
+                    forward_input=FunctionInput(
+                        make_input((2, 3, 4)),
+                    ),
+                    # fastpath doesn't run when bias=False
+                    reference_fn=fast_path_reference_fn if bias else None,
+                    desc=f'fastpath_{bias}_norm_first_{norm_first}'
+                )
+            )
+
+    return samples
+
+
+def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ), ]
+
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        # Using same mask for tgt and memory
+        memory_mask = tgt_mask
+        memory_key_padding_mask = tgt_key_padding_mask
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first,
+                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
+        if tgt_key_padding_mask is not None:
+            memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
+                ),
+                desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}'
+            ))
+
+    return samples
+
+
+def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = []
+    # Samples below are for validating the no-batch-dim support.
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for mask, key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        # Using same mask for tgt and memory
+        src_mask , tgt_mask = (mask,) * 2
+        src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                num_encoder_layers=1, num_decoder_layers=1,
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first,
+                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src = src.transpose(0, 1)
+            tgt = tgt.transpose(0, 1)
+        if key_padding_mask is not None:
+            src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                num_encoder_layers=1, num_decoder_layers=1,
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    src, tgt, tgt_mask=tgt_mask, src_mask=src_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+            ))
+    return samples
+
+
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(2, 3).random_(4))
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
+            desc='discontiguous'
+        ),
+    ]
+
+
+def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = []
+    bool_vals = (True, False)
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3)))
+    products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks)
+    for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products:
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True,
+                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
+                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
+                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
+                reference_fn=no_batch_dim_reference_mha,
+            )
+        )
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False,
+                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
+                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
+                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
+                reference_fn=partial(no_batch_dim_reference_mha, batch_first=False),
+            )
+        )
+
+    return samples
+
+
+def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10),
+            forward_input=FunctionInput(make_input(5), make_input(10)),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10, bias=True),
+            forward_input=FunctionInput(make_input(5), make_input(10)),
+            reference_fn=no_batch_dim_reference_fn,
+        )
+    ]
+
+    is_rnn = kwargs.get('is_rnn', False)
+    if is_rnn:
+        # RNN also supports `nonlinearity` argument.
+        # `tanh` is the default, so we check with `relu`
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'),
+                forward_input=FunctionInput(make_input(5), make_input(10)),
+                reference_fn=no_batch_dim_reference_fn,
+            )
+        )
+
+    return samples
+
+
+def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = (
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10),
+            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
+            reference_fn=no_batch_dim_reference_lstmcell,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10, bias=True),
+            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
+            reference_fn=no_batch_dim_reference_lstmcell,
+        ),
+    )
+
+    return samples
+
+def make_packed_sequence(inp, batch_sizes):
+    """Pack ``inp`` via ``pack_padded_sequence`` and move its grad flag onto the packed data.
+
+    ``inp`` is modified in place: its ``requires_grad`` flag is switched off before
+    packing, and the flag it originally carried is applied to ``seq.data`` instead.
+
+    Args:
+        inp: tensor to pack.
+        batch_sizes: forwarded as the second positional argument of
+            ``pack_padded_sequence`` (presumably the per-sequence lengths, despite
+            the parameter name - confirm against the torch signature).
+
+    Returns:
+        The object returned by ``pack_padded_sequence``.
+    """
+    # Remember the caller's flag before it is cleared on the line below.
+    required_grad = inp.requires_grad
+    inp.requires_grad_(False)  # user won't have access to inp so won't be able to get its grads
+    seq = pack_padded_sequence(inp, batch_sizes)
+    # Restore the original flag on the packed data rather than on ``inp``.
+    seq.data.requires_grad_(required_grad)
+    return seq
+
+
+def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    is_rnn = kwargs['is_rnn']
+    nonlinearity = ('relu', 'tanh')
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+
+    samples = []
+    if is_rnn:
+        prod_gen = product(nonlinearity, bias, batch_first, bidirectional)
+    else:
+        prod_gen = product(bias, batch_first, bidirectional)
+
+    for args in prod_gen:
+        if is_rnn:
+            nl, b, b_f, bidir = args
+        else:
+            b, b_f, bidir = args
+
+        cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        if is_rnn:
+            cons_args['nonlinearity'] = nl
+            cons_args_hidden['nonlinearity'] = nl
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((3, 2))),
+                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+            )
+        )
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))),
+                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+            )
+        )
+        if with_packed_sequence:
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(**cons_args),
+                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))),
+                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+                )
+            )
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(**cons_args),
+                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))),
+                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+                )
+            )
+
+    return samples
+
+
+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+    proj_sizes = (0, 2)
+
+    samples = []
+    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
+
+    for args in prod_gen:
+        b, b_f, bidir, proj_size = args
+        hidden_size = 3
+        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((2, 2))),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+        h_out = proj_size if proj_size > 0 else hidden_size
+        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), hx),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+
+    return samples
+
+
+
+def module_inputs_torch_nn_ReflectionPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((2, 3))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReflectionPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReflectionPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
+            forward_input=FunctionInput(make_input((3, 3, 3, 3, 3))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6, 7))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 2),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2), 3),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4), 5),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6), 7),
+            forward_input=FunctionInput(make_input((1, 2, 1, 2, 1))),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def padding1d_circular_ref(inp, pad):
+        r""" input:
+                [[[0., 1., 2.],
+                  [3., 4., 5.]]]
+                pad: (1, 2)
+                output:
+                    [[[2., 0., 1., 2., 0., 1.],
+                      [5., 3., 4., 5., 3., 4.]]]
+            """
+        return torch.cat([inp[:, :, -pad[0]:], inp, inp[:, :, :pad[1]]], dim=2)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 1)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def padding2d_circular_ref(inp, pad):
+        r"""input:
+                [[[[0., 1., 2],
+                   [3., 4., 5.]]]]
+                pad: (1, 2, 2, 1)
+        output:
+            [[[[2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.]]]]
+        """
+        inp = torch.cat([inp[:, :, -pad[2]:], inp, inp[:, :, :pad[3]]], dim=2)
+        return torch.cat([inp[:, :, :, -pad[0]:], inp, inp[:, :, :, :pad[1]]], dim=3)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 2, 1)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3, 2, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3, 3, 1)),
+            forward_input=FunctionInput(make_input((1, 1, 3, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+
+    def padding3d_circular_ref(inp, pad):
+        r"""input:
+                [[[[[ 0.,  1.,  2.],
+                    [ 3.,  4.,  5.]],
+                   [[ 6.,  7.,  8.],
+                    [ 9., 10., 11.]]]]]
+            pad: (1, 2, 2, 1, 1, 2)
+            output: [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
+        """
+        inp = torch.cat([inp[:, :, -pad[4]:], inp, inp[:, :, :pad[5]]], dim=2)
+        inp = torch.cat([inp[:, :, :, -pad[2]:], inp, inp[:, :, :, :pad[3]]], dim=3)
+        return torch.cat([inp[:, :, :, :, -pad[0]:], inp, inp[:, :, :, :, :pad[1]]], dim=4)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 2, 2, 1, 1, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3, 2, 1, 2, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+    ]
+
+
+# All these operators share similar issues on cuDNN and MIOpen
+# Tuple of DecorateInfo entries, each marking one TestModule test as an expected
+# failure on 'cuda' devices while its `active_if` condition holds. The first three
+# entries are active with cuDNN and without ROCm; the last one is active with ROCm
+# and is restricted to torch.float.
+rnn_gru_lstm_module_info_decorators = (
+    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
+    # We could not generate a fallback
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_grad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
+    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_gradgrad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # CUDNN GRU doesn't accept non-contiguous hx
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
+    )
+)
+
+# Start of module error inputs functions.
+
+def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="input has inconsistent input_size: got 11 expected 10"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(5, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 1, 1, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="Expected hidden to be 1D or 2D, got 4D instead"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20, 'relu'),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20, 'tanh'),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+    ]
+    return samples
+
+def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 11), (make_input(3, 20), make_input(3, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="input has inconsistent input_size: got 11 expected 10"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 21), make_input(3, 21))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(5, 20), make_input(5, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 1, 1, 20), make_input(3, 1, 1, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="Expected hx\\[0\\] to be 1D or 2D, got 4D instead"
+        ),
+    ]
+    return samples
+
+
+def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex="hidden_size must be greater than zero"
+        ),
+        ErrorModuleInput(
+            ModuleInput(constructor_input=FunctionInput(10, 10, 0)),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex="num_layers must be greater than zero"
+        ),
+    ]
+    return samples
+
+def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 2D or 3D input \(got 4D input\)",
+
+        ),
+    ]
+
+def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 3D or 4D input \(got 2D input\)",
+
+        ),
+    ]
+
+def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 4D or 5D input \(got 2D input\)",
+
+        ),
+    ]
+
+
+# True only when the MPS backend is available and is_macos_or_newer(15, 0) holds.
+# The `and` short-circuits, so the version query is skipped when MPS is unavailable.
+_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0)
+
+
+# Database of ModuleInfo entries in alphabetical order.
+module_db: list[ModuleInfo] = [
+    ModuleInfo(torch.nn.AdaptiveAvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
+               skips=(
+                   # Fails on MPS backend if input/output sizes are not divisible
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveAvgPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
+               skips=(
+                   # Fails on MPS backend if input/output sizes are not divisible
+                   DecorateInfo(skipMPS),
+                   # Fails on backward check if output size is 1x1
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                   ),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveAvgPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool1d,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d,
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool2d,
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool1d,
+               ),
+    ModuleInfo(torch.nn.AvgPool2d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool2d,
+               skips=(
+                   # The difference between channels last backward and
+                   # channels first backward of AvgPool2d on CUDA is too large
+                   # See https://github.com/pytorch/pytorch/issues/107201
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='cuda',),
+               ),),
+    ModuleInfo(torch.nn.AvgPool3d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
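+                   # NOTE(review): the original comment read "No channels_last support for AvgPool1d as it
+                   # does not take 4D inputs", which does not match this AvgPool3d entry; confirm the
+                   # actual reason test_memory_format is skipped here.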
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # backward not supported on MPS backend
+                   DecorateInfo(skipMPS, 'TestModule', 'test_non_contiguous_tensors'),)
+               ),
+    ModuleInfo(torch.nn.BatchNorm1d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
+               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
+               skips=(
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ))
+               ),
+    ModuleInfo(torch.nn.BatchNorm2d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
+               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
+               skips=(
+                   # See https://github.com/pytorch/pytorch/issues/134580
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')),
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),)
+               ),
+    ModuleInfo(torch.nn.BatchNorm3d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
+               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),)
+               ),
+    ModuleInfo(torch.nn.CELU,
+               module_inputs_func=module_inputs_torch_nn_CELU,
+               # not MPS specific, will be xfailed for all devices in next PR
+               skips=(
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.Conv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Conv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS, device_type="mps"),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               skips=(
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               skips=(
+                   # Fails on backward check because ViewAsRealBackward apply contiguous for grad
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64, torch.complex128]),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # ConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.CosineEmbeddingLoss,
+               module_inputs_func=module_inputs_torch_nn_CosineEmbeddingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.ELU,
+               module_inputs_func=module_inputs_torch_nn_ELU,
+               # not MPS specific, will be xfailed for all devices in next PR
+               skips=(
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.FractionalMaxPool2d,
+               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.FractionalMaxPool3d,
+               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.L1Loss,
+               module_inputs_func=module_inputs_torch_nn_L1Loss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.SmoothL1Loss,
+               module_inputs_func=module_inputs_torch_nn_SmoothL1Loss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   # NS: Still fails on MacOS15.1
+                   DecorateInfo(skipIfMPS, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=[torch.float16], device_type='mps'),),
+               ),
+    ModuleInfo(torch.nn.LazyConv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # LazyConv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # LazyConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Linear,
+               module_inputs_func=module_inputs_torch_nn_Linear,
+               skips=(
+                   # No channels_last support for Linear currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Bilinear,
+               module_inputs_func=module_inputs_torch_nn_Bilinear,
+               decorators=[
+                   DecorateInfo(
+                       toleranceOverride({
+                           torch.float32: tol(atol=1e-4, rtol=1e-4),
+                           torch.float64: tol(atol=1e-4, rtol=1e-4)}),
+                       'TestModule', 'test_forward', device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for Bilinear currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.LPPool1d,
+               module_inputs_func=module_inputs_torch_nn_LPPool1d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.LPPool2d,
+               module_inputs_func=module_inputs_torch_nn_LPPool2d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
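+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   # NOTE(review): the `active_if` expression below,
+                   # `operator.itemgetter('training') and not _macos15_or_newer`, evaluates to the plain
+                   # bool `not _macos15_or_newer` (an itemgetter object is always truthy), so the
+                   # 'training' condition is not actually consulted.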
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training') and not _macos15_or_newer,
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LPPool3d,
+               module_inputs_func=module_inputs_torch_nn_LPPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(skipIfMPS, device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.MaxPool1d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool1d,
+               ),
+    ModuleInfo(torch.nn.MaxPool2d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool2d,
+               ),
+    ModuleInfo(torch.nn.MaxPool3d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               ),
+    ModuleInfo(torch.nn.KLDivLoss,
+               module_inputs_func=module_inputs_torch_nn_KLDivLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # https://github.com/pytorch/pytorch/issues/115588
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.MSELoss,
+               module_inputs_func=module_inputs_torch_nn_MSELoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.MarginRankingLoss,
+               module_inputs_func=module_inputs_torch_nn_MarginRankingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.MultiLabelMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiLabelMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # 'aten::multilabel_margin_loss_forward' is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'),
+                   # derivative for aten::multilabel_margin_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.MultiMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # 'aten::multi_margin_loss' is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'),
+                   # RuntimeError: derivative for aten::multi_margin_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.SoftMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_SoftMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.MultiLabelSoftMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiLabelSoftMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.NLLLoss,
+               module_inputs_func=module_inputs_torch_nn_NLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.GaussianNLLLoss,
+               module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
+    ModuleInfo(torch.nn.PoissonNLLLoss,
+               module_inputs_func=module_inputs_torch_nn_PoissonNLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
+    ModuleInfo(torch.nn.HingeEmbeddingLoss,
+               module_inputs_func=module_inputs_torch_nn_HingeEmbeddingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.HuberLoss,
+               module_inputs_func=module_inputs_torch_nn_HuberLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: seemingly incorrect output dtype
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.BCELoss,
+               module_inputs_func=module_inputs_torch_nn_BCELoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.BCEWithLogitsLoss,
+               module_inputs_func=module_inputs_torch_nn_BCEWithLogitsLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # see #119108: tolerance issue
+                   DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.CrossEntropyLoss,
+               module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss,
+               dtypes=get_all_fp_dtypes(include_half=True, include_bfloat16=False),
+               decorators=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
+                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
+                                "test_forward", dtypes=[torch.float16], device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
+                                device_type='cuda'),),
+               ),
+    ModuleInfo(torch.nn.CTCLoss,
+               module_inputs_func=module_inputs_torch_nn_CTCLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # The operator aten::_ctc_loss is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps',),
+                   # derivative for aten::_ctc_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   # https://github.com/pytorch/pytorch/issues/115585
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),)
+               ),
+    ModuleInfo(torch.nn.GELU,
+               module_inputs_func=module_inputs_torch_nn_GELU,
+               skips=(
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.GLU,
+               module_inputs_func=module_inputs_torch_nn_GLU,
+               ),
+    ModuleInfo(torch.nn.GroupNorm,
+               module_inputs_func=module_inputs_torch_nn_GroupNorm,
+               module_error_inputs_func=module_error_inputs_torch_nn_GroupNorm,
+               dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True),
+               skips=(
+                   # Tracking at https://github.com/pytorch/pytorch/issues/98089
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_memory_format', device_type='cpu'),
+                   # No channels_last support for GroupNorm currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'),
+                   DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad",
+                                active_if=TEST_WITH_ROCM, device_type='cuda'),)
+               ),
+    ModuleInfo(torch.nn.Hardshrink,
+               module_inputs_func=module_inputs_torch_nn_Hardshrink,
+               ),
+    ModuleInfo(torch.nn.Hardswish,
+               module_inputs_func=module_inputs_torch_nn_Hardswish,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.Hardtanh,
+               module_inputs_func=module_inputs_torch_nn_Hardtanh,
+               ),
+    ModuleInfo(torch.nn.InstanceNorm1d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=1),
+               train_and_eval_differ=True,
+               skips=(
+                   # No channels_last support for InstanceNorm1d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.InstanceNorm2d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=2),
+               train_and_eval_differ=True,
+               skips=(
+                   # No channels_last support for InstanceNorm2d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.InstanceNorm3d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3),
+               train_and_eval_differ=True,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_memory_format'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous_tensors'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_forward'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_save_load'),
+                   # No channels_last support for InstanceNorm3d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.LocalResponseNorm,
+               module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
+               ),
+    ModuleInfo(torch.nn.LayerNorm,
+               module_inputs_func=module_inputs_torch_nn_LayerNorm,
+               skips=(
+                   # No channels_last support for LayerNorm currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.RMSNorm,
+               module_inputs_func=module_inputs_torch_nn_RMSNorm,
+               ),
+    # TransformerEncoder takes the same inputs as TransformerEncoderLayer
+    ModuleInfo(torch.nn.TransformerEncoder,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoder,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerEncoder currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # Doesn't support device / dtype kwargs directly because it is just a
+                   # container of TransformerEncoderLayers.
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),)
+               ),
+    ModuleInfo(torch.nn.TransformerEncoderLayer,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='cpu', active_if=IS_WINDOWS),
+                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}),
+                                'TestModule', 'test_forward',
+                                device_type='mps'),
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerEncoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.TransformerDecoderLayer,
+               module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerDecoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Transformer,
+               module_inputs_func=module_inputs_torch_nn_Transformer,
+               # Inputs are too large to run with slow gradcheck
+               # https://github.com/pytorch/pytorch/issues/117140
+               gradcheck_fast_mode=True,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for Transformer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.MultiheadAttention,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
+               skips=(
+                   # No channels_last support for MultiheadAttention currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Embedding,
+               module_inputs_func=module_inputs_torch_nn_Embedding,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='mps')],
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.ReLU,
+               module_inputs_func=module_inputs_torch_nn_ReLU,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LeakyReLU,
+               module_inputs_func=module_inputs_torch_nn_LeakyReLU,
+               ),
+    ModuleInfo(torch.nn.ReLU6,
+               module_inputs_func=module_inputs_torch_nn_ReLU6,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.PReLU,
+               module_inputs_func=module_inputs_torch_nn_PReLU,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.RNNCell,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
+               ),
+    ModuleInfo(torch.nn.GRUCell,
+               module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
+               ),
+    ModuleInfo(torch.nn.LSTMCell,
+               module_inputs_func=module_inputs_torch_nn_LSTMCell,
+               module_error_inputs_func=module_error_inputs_torch_nn_LSTMCell,
+               ),
+    ModuleInfo(torch.nn.Sigmoid,
+               module_inputs_func=module_inputs_torch_nn_Sigmoid,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LogSigmoid,
+               module_inputs_func=module_inputs_torch_nn_LogSigmoid,
+               skips=(
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.SiLU,
+               module_inputs_func=module_inputs_torch_nn_SiLU,
+               ),
+    ModuleInfo(torch.nn.Softmax,
+               module_inputs_func=module_inputs_torch_nn_Softmax,
+               ),
+    ModuleInfo(torch.nn.Softmax2d,
+               module_inputs_func=module_inputs_torch_nn_Softmax2d,
+               skips=(
+                   # no channels last support for Softmax2d currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.LogSoftmax,
+               module_inputs_func=module_inputs_torch_nn_LogSoftmax,
+               skips=(
+                   # no channels last support for LogSoftmax currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: inf nan error
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.Softmin,
+               module_inputs_func=module_inputs_torch_nn_Softmin,
+               skips=(
+                   # no channels last support for Softmin currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Softplus,
+               module_inputs_func=module_inputs_torch_nn_Softplus,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Softshrink,
+               module_inputs_func=module_inputs_torch_nn_Softshrink,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Softsign,
+               module_inputs_func=module_inputs_torch_nn_Softsign,
+               ),
+    ModuleInfo(torch.nn.Tanh,
+               module_inputs_func=module_inputs_torch_nn_Tanh,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.Tanhshrink,
+               module_inputs_func=module_inputs_torch_nn_Tanhshrink,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.Threshold,
+               module_inputs_func=module_inputs_torch_nn_Threshold,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Mish,
+               module_inputs_func=module_inputs_torch_nn_Mish,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.RNN,
+               train_and_eval_differ=True,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               decorators=rnn_gru_lstm_module_info_decorators
+               ),
+    ModuleInfo(torch.nn.GRU,
+               train_and_eval_differ=True,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.LSTM,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_LSTM,
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               skips=(
+                   # LSTM with projections is not currently supported with MPS
+                   DecorateInfo(skipMPS),),
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.ReflectionPad1d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
+               ),
+    ModuleInfo(torch.nn.ReflectionPad2d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReflectionPad3d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReplicationPad1d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad1d,
+               ),
+    ModuleInfo(torch.nn.ReplicationPad2d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReplicationPad3d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.SELU,
+               module_inputs_func=module_inputs_torch_nn_SELU,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.ZeroPad1d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
+               ),
+    ModuleInfo(torch.nn.ZeroPad2d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ZeroPad3d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.CircularPad1d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad1d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
+               ),
+    ModuleInfo(torch.nn.CircularPad2d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad2d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
+               ),
+    ModuleInfo(torch.nn.CircularPad3d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad3d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),)
+               ),
+    ModuleInfo(torch.nn.ConstantPad1d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
+               ),
+    ModuleInfo(torch.nn.ConstantPad2d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ConstantPad3d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               )
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedd0c92b6a4da6d7a0e1d30efa3551c05e11208
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py
@@ -0,0 +1,840 @@
+import unittest
+from collections.abc import Sequence
+from typing import Optional
+
+import torch
+
+from .common_utils import MACOS_VERSION
+from .opinfo.core import DecorateInfo, OpInfo
+
+
+if torch.backends.mps.is_available():
+
+    def mps_ops_modifier(
+        ops: Sequence[OpInfo],
+        device_type: str = "mps",
+        xfail_exclusion: Optional[list[str]] = None,
+        sparse: bool = False,
+    ) -> Sequence[OpInfo]:
+        if xfail_exclusion is None:
+            xfail_exclusion = []
+
+        # Supported complex OPS
+        SUPPORTED_COMPLEX_OPS = {
+            "__radd__",
+            "__rmul__",
+            "__rsub__",
+            "__getitem__",
+            "_unsafe_masked_index",
+            "_unsafe_masked_index_put_accumulate",
+            "abs",
+            "add",
+            "alias_copy",
+            "argwhere",
+            "atleast_1d",
+            "atleast_2d",
+            "atleast_3d",
+            "as_strided",
+            "as_strided_copy",
+            "as_strided_scatter",
+            "asin",
+            "asinh",
+            "acos",
+            "atan",
+            "broadcast_tensors",
+            "broadcast_to",
+            "chalf",
+            "cfloat",
+            "chunk",
+            "clone",
+            "conj",
+            "conj_physical",
+            "contiguous",
+            "cos",
+            "cosh",
+            "diag",
+            "diag_embed",
+            "diagflat",
+            "diagonal",
+            "diagonal_copy",
+            "diagonal_scatter",
+            "divno_rounding_mode",
+            "dsplit",
+            "empty",
+            "empty_permuted",
+            "empty_strided",
+            "exp",
+            "expm1",
+            "exp2",
+            "expand",
+            "expand_as",
+            "expand_copy",
+            "flatten",
+            "fill",
+            "full",
+            "full_like",
+            "H",
+            "hsplit",
+            "imag",
+            "index_add",
+            "index_copy",
+            "index_select",
+            "index_put",
+            "isfinite",
+            "isinf",
+            "isreal",
+            "item",
+            "kron",
+            "linalg.diagonal",
+            "linalg.householder_product",
+            "linalg.svd",
+            "log10",
+            "log1p",
+            "log2",
+            "log",
+            "logaddexp",
+            "logaddexp2",
+            "mH",
+            "mT",
+            "masked_fill",
+            "masked_scatter",
+            "masked_select",
+            "meshgridlist_of_tensors",
+            "meshgridvariadic_tensors",
+            "movedim",
+            "mul",
+            "narrow",
+            "narrow_copy",
+            "neg",
+            "new_full",
+            "new_ones",
+            "new_zeros",
+            "nn.functional.conv1d",
+            "nn.functional.conv2d",
+            "nn.functional.conv_transpose1d",
+            "nn.functional.conv_transpose2d",
+            "nn.functional.conv_transpose3d",
+            "nn.functional.feature_alpha_dropoutwithout_train",
+            "nn.functional.padcircular",
+            "nn.functional.softsign",
+            "nn.functional.tanhshrink",
+            "nn.functional.unfold",
+            "nonzero",
+            "ones",
+            "ones_like",
+            "outer",
+            "permute",
+            "permute_copy",
+            "positive",
+            "randn",
+            "ravel",
+            "real",
+            "repeat_interleave",
+            "reshape_as",
+            "reshape",
+            "resolve_conj",
+            "resolve_neg",
+            "rsqrt",
+            "rsub",
+            "scalar_tensor",
+            "select",
+            "sgn",
+            "sigmoid",
+            "sin",
+            "sinc",
+            "sinh",
+            "slice",
+            "special.spherical_bessel_j0",
+            "special.entr",
+            "special.xlog1py",
+            "special.zeta",
+            "split",
+            "split_with_sizes",
+            "split_with_sizes_copy",
+            "splitlist_args",
+            "sqrt",
+            "squeeze",
+            "squeeze_copy",
+            "squeezemultiple",
+            "sub",
+            "svd",
+            "t",
+            "t_copy",
+            "tanh",
+            "tan",
+            "tensor_split",
+            "transpose",
+            "transpose_copy",
+            "tril",
+            "triu",
+            "true_divide",
+            "T",
+            "unbind",
+            "unbind_copy",
+            "unflatten",
+            "unfold",
+            "unfold_copy",
+            "unsafe_chunk",
+            "unsafe_split",
+            "unsqueeze",
+            "unsqueeze_copy",
+            "view_as",
+            "view_as_real",
+            "view",
+            "view_copy",
+            "vsplit",
+            "zero_",
+            "zeros",
+            "zeros_like",
+            "__rdiv__",
+            "__rmatmul__",
+            "_chunk_cat",
+            "acosh",
+            "all",
+            "allclose",
+            "angle",
+            "any",
+            "addcdiv",
+            "addcmul",
+            "addmmdecomposed",
+            "addmv",
+            "atanh",
+            "bfloat16",
+            "bmm",
+            "bool",
+            "cartesian_prod",
+            "cat",
+            "char",
+            "column_stack",
+            "combinations",
+            "corrcoef",
+            "constant_pad_nd",
+            "cov",
+            "count_nonzero",
+            "diff",
+            "div",
+            "dot",
+            "dstack",
+            "einsum",
+            "eq",
+            "equal",
+            "eye",
+            "fft.fft",
+            "fft.fft2",
+            "fft.fftn",
+            "fft.fftshift",
+            "fft.ifft",
+            "fft.ifft2",
+            "fft.ifftn",
+            "fft.ifftshift",
+            "fft.irfftn",
+            "fft.irfft2",
+            "fft.irfft",
+            "fft.hfftn",
+            "fft.hfft2",
+            "fft.hfft",
+            "flip",
+            "fliplr",
+            "flipud",
+            "float",
+            "gradient",
+            "half",
+            "hstack",
+            "inner",
+            "int",
+            "isclose",
+            "isnan",
+            "ldexp",
+            "lerp",
+            "linalg.multi_dot",
+            "linalg.pinv",
+            "linspace",
+            "linspacetensor_overload",
+            "logical_and",
+            "logical_not",
+            "logical_or",
+            "logical_xor",
+            "logsumexp",
+            "long",
+            "masked.mean",
+            "masked.prod",
+            "masked.std",
+            "masked.sum",
+            "masked.var",
+            "masked.logsumexp",
+            "matmul",
+            "mean",
+            "mm",
+            "mv",
+            "ne",
+            "nn.functional.padconstant",
+            "nn.functional.padreflect",
+            "nn.functional.padreplicate",
+            "nn.functional.pixel_shuffle",
+            "nn.functional.pixel_unshuffle",
+            "nn.functional.rms_norm",
+            "pinverse",
+            "prod",
+            "reciprocal",
+            "roll",
+            "rot90",
+            "short",
+            "square",
+            "stack",
+            "stft",
+            "sum",
+            "sum_to_size",
+            "tensordot",
+            "trace",
+            "trapz",
+            "trapezoid",
+            "vstack",
+            "where",
+            "byte",
+        }
+
+        MACOS_BEFORE_14_4_XFAILLIST = {
+            # These ops work fine in 14.4 but fail in 14.2 or 13.x
+            "fft.hfft2": [torch.complex64],
+        }
+
+        # Those ops are not expected to work
+        UNIMPLEMENTED_XFAILLIST: dict[str, Optional[list]] = {
+            # Failures due to lack of op implementation on MPS backend
+            "logspace": None,
+            "logspacetensor_overload": None,
+            "linalg.eig": None,
+            "linalg.eigvals": None,
+            "put": None,
+            "cauchy_": None,
+            "cauchy": None,
+            "cholesky_inverse": None,
+            "cholesky_solve": None,
+            "frexp": None,
+            "gcd": None,
+            "geqrf": None,
+            "nn.functional.grid_sample": None,  # Unsupported Border padding mode
+            "hash_tensor": None,
+            "heaviside": None,
+            "index_reduceprod": None,
+            "index_reducemean": None,
+            "index_reduceamax": None,
+            "index_reduceamin": None,
+            # "kthvalue": None,
+            "lcm": None,
+            "linalg.cond": None,
+            "linalg.eigh": None,
+            "linalg.eigvalsh": None,
+            "linalg.ldl_factor": None,
+            "linalg.ldl_factor_ex": None,
+            "linalg.ldl_solve": None,
+            "linalg.lstsq": None,
+            "linalg.lstsqgrad_oriented": None,
+            "linalg.matrix_norm": [torch.float32],
+            "linalg.norm": [torch.float32],
+            "linalg.normsubgradients_at_zero": [torch.float32],
+            "linalg.qr": None,
+            "linalg.svdvals": None,
+            "linalg.vecdot": None,
+            "masked.median": None,
+            "matrix_exp": None,
+            "mode": None,
+            "normnuc": None,
+            "nn.functional.fractional_max_pool2d": None,
+            "nn.functional.fractional_max_pool3d": None,
+            "nn.functional.adaptive_avg_pool3d": None,
+            "nn.functional.adaptive_max_pool3d": None,
+            "nn.functional.interpolatearea": None,
+            "nn.functional.interpolatebicubic": [torch.uint8],
+            "nn.functional.ctc_loss": None,
+            "nn.functional.multi_margin_loss": None,
+            "nn.functional.multilabel_margin_loss": None,
+            "nn.functional.pdist": None,
+            "nn.functional.rrelu": None,
+            "nn.functional.norm": None,
+            "ormqr": None,
+            "pca_lowrank": None,
+            "qr": None,
+            "scatter_reduceamax": [torch.int32, torch.int64]
+            if MACOS_VERSION < 15.0
+            else [torch.int64],
+            "scatter_reduceamin": [torch.int32, torch.int64]
+            if MACOS_VERSION < 15.0
+            else [torch.int64],
+            "segment_reduce": None,
+            "_segment.reduce": None,
+            "segment.reduce": None,
+            "segment_reduce_offsets": None,
+            "_segment_reduce_offsets": None,
+            "_segment_reduce_lengths": None,
+            "_segment_reducelengths": None,
+            "_segment_reduceoffsets": None,
+            "sparse.mm": None,
+            "sparse.sampled_addmm": None,
+            "sparse.mmreduce": None,
+            "special.airy_ai": None,
+            "special.erfcx": None,
+            "special.laguerre_polynomial_l": None,
+            "special.legendre_polynomial_p": None,
+            "special.log_ndtr": None,
+            "special.ndtri": None,
+            "svd_lowrank": None,
+            "symeig": None,
+            "take": None,
+            "to": None,
+            "vdot": None,
+            "segment_reduce_": None,
+            "_upsample_bilinear2d_aa": [torch.uint8],  # uint8 is for CPU only
+            "_upsample_bicubic2d_aa": [torch.uint8],  # uint8 is for CPU only
+            "geometric": None,
+            "geometric_": None,
+            "log_normal_": None,
+            "log_normal": None,
+            "cdouble": None,
+            "double": None,
+            "nn.functional.softminwith_dtype": None,
+            "log_softmaxwith_dtype": None,
+            "softmaxwith_dtype": None,
+            "float_power": None,
+            "linalg.matrix_rankhermitian": None,
+            "linalg.pinvhermitian": None,
+            "nonzero_static": None,
+            # MPS: input sizes must be divisible by output sizes
+            "nn.functional.adaptive_avg_pool1d": None,
+            "nn.functional.adaptive_avg_pool2d": None,
+            # Convolution for integral types is not supported on MPS
+            "nn.functional.conv1d": [torch.int64],
+            "nn.functional.conv2d": [torch.int64],
+            "nn.functional.conv3d": [torch.int64],
+            "nn.functional.conv_transpose1d": [torch.int64],
+            "nn.functional.conv_transpose2d": [torch.int64, torch.bfloat16],
+            "nn.functional.conv_transpose3d": [
+                torch.int64,
+                torch.bfloat16,
+                torch.float16,
+            ],
+            # Unsupported dtypes
+            "histc": [torch.float16, torch.bfloat16],
+            # GEMM on MPS is not supported for integral types
+            "nn.functional.linear": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            # returned output on CPU is float64
+            "bincount": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+        }
+        UNIMPLEMENTED_XFAILLIST_SPARSE: dict[str, Optional[list]] = {
+            "logspace": None,
+            "logspacetensor_overload": None,
+            "linalg.eig": None,
+            "linalg.eigvals": None,
+            "put": None,
+        }
+
+        if MACOS_VERSION < 15.0:
+            UNIMPLEMENTED_XFAILLIST.update(
+                {
+                    "quantile": None,
+                    "nanquantile": None,
+                }
+            )
+        if sparse:
+            UNIMPLEMENTED_XFAILLIST.update(UNIMPLEMENTED_XFAILLIST_SPARSE)
+
+        UNDEFINED_XFAILLIST: dict[str, Optional[list]] = {
+            # Top 60 operators
+            # topk fails with duplicate indices
+            "topk": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            # Failures due to random output that they generate using
+            # Philox engine causing mismatch with CPU results
+            "multinomial": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],  # random results
+            "uniform": [torch.float16, torch.float32, torch.bfloat16],
+            "rand_like": [torch.float16, torch.float32, torch.bfloat16],
+            "randint": None,
+            "randint_like": None,
+            "randn": None,
+            "randn_like": None,
+            "bernoulli": [torch.float16, torch.float32, torch.bfloat16],
+            "exponential": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.feature_alpha_dropoutwith_train": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],
+            "normal": [torch.float16, torch.float32, torch.bfloat16],
+            "normalin_place": [torch.float16, torch.float32, torch.bfloat16],
+            "normalnumber_mean": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.alpha_dropout": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],
+            "nn.functional.dropout": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.dropout2d": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.dropout3d": [torch.float16, torch.float32, torch.bfloat16],
+            # See https://github.com/pytorch/pytorch/issues/111479
+            "nn.functional.multi_head_attention_forward": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],
+            # zero to negative integer powers are undefined
+            "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
+            "resize_": [torch.float16, torch.float32, torch.bfloat16],
+            "resize_as_": [torch.float16, torch.float32, torch.bfloat16],
+            # CPU Errors:
+            "addr": [
+                torch.bool,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],  # "addmv_impl_cpu" not implemented for 'Half'
+            "as_stridedpartial_views": None,  # cpu result off, showing random values
+            # random results
+            # mps vs cpu:
+            # Mismatched elements: 40 / 96 (41.7%)
+            # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
+            # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
+            # cuda(2.0.0.dev20230301+cu117) vs cpu:
+            # Mismatched elements: 56 / 96 (58.3%)
+            # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
+            # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
+            "nn.functional.scaled_dot_product_attention": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],
+        }
+
+        ON_MPS_XFAILLIST: dict[str, Optional[list]] = {
+            # Failures due to lack of implementation of downstream functions on MPS backend
+            # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
+            "linalg.matrix_rank": None,
+            # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")`
+            "arange": [torch.uint8],
+            # before macOS 13.2 it falls back to cpu and pass the forward pass
+            "grid_sampler_2d": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],  # Unsupported Border padding mode
+            # Failure due to precision issue for fp16
+            # on both cpu and mps there are test cases that might produce inf result
+            # 'nn.functional.pairwise_distance': [torch.float16],
+            # test below passes on macOS 12 as it falls back to cpu
+            # Argsort case using duplicate indices (undefined behaviour):
+            #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
+            #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
+            # Elements from index 30 and 5133 are both equal.
+            # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
+            "argsort": [
+                torch.float16,
+                torch.int8,
+                torch.uint8,
+                torch.bool,
+                torch.bfloat16,
+            ],
+            # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
+            # The values of the sorted tensor match the CPU,
+            # but in case of the returned indices this results in undefined behaviour.
+            "sort": [
+                torch.int8,
+                torch.uint8,
+                torch.bool,
+                torch.float16,
+                torch.bfloat16,
+            ],
+        }
+
+        EMPTY_OPS_SKIPLIST = {
+            # Fill tensors with uninitialized data, causing mismatch with CPU.
+            # They occasionally match, thus skipping them.
+            # See https://github.com/pytorch/pytorch/issues/100175
+            "new_empty": None,
+            "new_empty_strided": None,
+            "empty_strided": None,
+            # CPU: empty is returning all 0's and there is a mismatch with MPS
+            # allocation (MacOS 13). According to the docs, the data is uninitialized:
+            # https://pytorch.org/docs/2.0/generated/torch.empty.html
+            "empty": None,
+            "empty_like": None,
+            "empty_permuted": None,
+        }
+
+        SKIPLIST = {
+            # Unsupported
+            # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
+            "nn.functional.conv3d": None,
+            # The CPU impl of grid_sampler_3d does not use opmath_t, so it has a
+            # large amount of error compared with the MPS impl for half
+            # precision types. So we have to skip these for now.
+            "grid_sampler_3d": [torch.float16, torch.bfloat16],
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            if device_type is not None:
+                d.device_type = device_type
+
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            addDecorator(
+                op,
+                DecorateInfo(
+                    unittest.expectedFailure,
+                    dtypes=[
+                        torch.double,
+                        torch.cdouble,
+                    ],
+                ),
+            )
+            if sparse:
+                # Skipped due to test_sparse_zero_dims test in test_sparse.py which allocates empty tensor
+                # which leads to unexpected success with it
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.skip(
+                            "Skipped due to MPS not supporting complex128 tensors"
+                        ),
+                        dtypes=[
+                            torch.complex128,
+                        ],
+                    ),
+                )
+            if key in EMPTY_OPS_SKIPLIST:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.skip("Skipping empty ops."),
+                        dtypes=EMPTY_OPS_SKIPLIST[key],
+                    ),
+                )
+            if key in SKIPLIST:
+                addDecorator(
+                    op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key])
+                )
+            for xfaillist in [
+                UNIMPLEMENTED_XFAILLIST,
+                UNDEFINED_XFAILLIST,
+                ON_MPS_XFAILLIST,
+            ]:
+                if key in xfaillist and key not in xfail_exclusion:
+                    addDecorator(
+                        op,
+                        DecorateInfo(unittest.expectedFailure, dtypes=xfaillist[key]),
+                    )
+
+            if (
+                key in MACOS_BEFORE_14_4_XFAILLIST
+                and key not in xfail_exclusion
+                and (MACOS_VERSION < 14.4)
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=MACOS_BEFORE_14_4_XFAILLIST[key],
+                    ),
+                )
+
+            # If ops is not supported for complex types, expect it to fail
+            if key not in SUPPORTED_COMPLEX_OPS:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=[torch.complex32, torch.complex64],
+                    ),
+                )
+
+        return ops
+
+    def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
+        XFAILLIST_GRAD = {
+            # Unimplemented ops
+            "_segment_reduce": [torch.float16, torch.float32],
+            "_chunk_cat": [torch.float16, torch.float32],
+            "_upsample_bilinear2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
+            "_upsample_bicubic2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
+            "sparse.mmreduce": [torch.float32],  # csr not supported
+            "linalg.householder_product": None,
+            "unique_consecutive": [torch.float16, torch.float32],
+            "scalar_tensor": [torch.float16, torch.float32],
+            "cdist": [torch.float32],
+            "masked.scatter": [torch.float16, torch.float32],
+            "grid_sampler_3d": None,
+            "index_fill": [torch.float16, torch.float32],  # missing `aten::_unique`.
+            "igamma": None,  # currently not supported for any device
+            "igammac": None,  # currently not supported for any device
+            "linalg.solve": [torch.float16, torch.float32],  # missing `aten::lu_solve`.
+            "linalg.solve_ex": [
+                torch.float16,
+                torch.float32,
+            ],  # missing `aten::lu_solve`.
+            "linalg.tensorsolve": [
+                torch.float16,
+                torch.float32,
+            ],  # missing `aten::lu_solve`.
+            "aminmax": [torch.float32, torch.float16],
+            "special.i1": [torch.float16],  # "i1_backward" not implemented for 'Half'
+            "special.i1e": [torch.float16],  # "i1e_backward" not implemented for 'Half'
+            # Correctness issues
+            "atanh": [torch.float32],
+            # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
+            # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
+            # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS.
+            # Running `msort` with stable `sort` passes.
+            "msort": [torch.float16],
+            # Random output
+            "exponential": [torch.float16, torch.float32],
+            # CPU errors
+            # derivative for zeta is not implemented
+            "special.zeta": None,
+            # derivative for aten::nextafter is not implemented on CPU
+            "nextafter": None,
+            # derivative for aten::floor_divide is not implemented on CPU
+            "floor_divide": [torch.float16, torch.float32],
+            # derivative for aten::narrow_copy is not implemented on CPU
+            "narrow_copy": [torch.float16, torch.float32],
+            # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
+            "histogramdd": [torch.float16, torch.float32],
+            # derivative for aten::histogram is not implemented
+            "histogram": [torch.float16, torch.float32],
+            # 'bool' object is not iterable
+            "allclose": [torch.float16, torch.float32],
+            "equal": [torch.float16, torch.float32],
+            # 'float' object is not iterable
+            "item": [torch.float16, torch.float32],
+            # cpu error: grad requires non-empty inputs
+            "randn": [torch.float16, torch.float32],
+            "signal.windows.bartlett": [torch.float32],
+            "signal.windows.blackman": [torch.float32],
+            "signal.windows.cosine": [torch.float32],
+            "signal.windows.exponential": [torch.float32],
+            "signal.windows.gaussian": [torch.float32],
+            "signal.windows.general_cosine": [torch.float32],
+            "signal.windows.general_hamming": [torch.float32],
+            "signal.windows.hamming": [torch.float32],
+            "signal.windows.hann": [torch.float32],
+            "signal.windows.kaiser": [torch.float32],
+            "signal.windows.nuttall": [torch.float32],
+            "eye": [torch.float16, torch.float32],
+            # topk fails with duplicate indices
+            "topk": [torch.float16],
+            # Could not run 'aten::uniform_' with arguments from the 'SparseCPU' backend
+            "to_sparse": None,
+            # Exception: the derivative for '_unique2' is not implemented.
+            "unique": None,
+        }
+
+        SKIPLIST_GRAD = {
+            "nn.functional.pairwise_distance": [torch.float16],
+            # failed assertion `destination datatype must be fp32'
+            "nn.functional.conv1d": [torch.float16],
+            "nn.functional.conv2d": [torch.float16],
+            "nn.functional.conv3d": [torch.float16],
+            "nn.functional.conv_transpose1d": [torch.float16],
+            "nn.functional.conv_transpose2d": [torch.float16],
+            "nn.functional.conv_transpose3d": [torch.float16],
+        }
+
+        ON_MPS_XFAILLIST = {
+            # Failures due to lack of implementation of downstream functions on MPS backend
+            # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
+            "linalg.matrix_rank": None,
+            # Exception: Caused by sample input at index 3 on MPS
+            "nn.functional.conv3d": [torch.float32],
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in XFAILLIST_GRAD:
+                addDecorator(
+                    op,
+                    DecorateInfo(unittest.expectedFailure, dtypes=XFAILLIST_GRAD[key]),
+                )
+
+            if key in SKIPLIST_GRAD:
+                addDecorator(op, DecorateInfo(unittest.skip, dtypes=SKIPLIST_GRAD[key]))
+
+            if key in ON_MPS_XFAILLIST:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure, dtypes=ON_MPS_XFAILLIST[key]
+                    ),
+                )
+
+        return ops
+
+    def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:  # Mark the listed ops as expected failures for error-input tests; mutates each matching op and returns the same `ops`.
+        # Error input samples do not take a dtype argument.
+        XFAILLIST = {  # set of keys compared against op.name + op.variant_test_name
+            # Exceptions are not raised
+            "__rmod__",
+            "__rsub__",
+            "__rpow__",
+            "clamp_max",
+            "clamp_min",
+            "masked_scatter",
+            # unsupported float64 dtype
+            "multinomial",
+            "nn.functional.conv1d",
+            "nn.functional.conv2d",
+            "nn.functional.conv3d",
+            "gather",
+            "scatter",
+            "scatter_add",
+            # MPS does not support tensor dimensions > 16
+            "amax",
+            "amin",
+            "aminmax",
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:  # append `d` to the op's decorators
+            op.decorators = op.decorators + (d,)  # rebinds to a new tuple rather than mutating in place
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in XFAILLIST:
+                addDecorator(op, DecorateInfo(unittest.expectedFailure))  # no dtypes argument is passed here
+
+        return ops
+else:
+
+    def mps_ops_modifier(  # Fallback defined in the `else:` branch (its condition is not visible here): performs no modification.
+        ops: Sequence[OpInfo],
+        device_type: str = "mps",  # accepted but unused in this fallback
+        xfail_exclusion: Optional[list[str]] = None,  # accepted but unused in this fallback
+        sparse: bool = False,  # accepted but unused in this fallback
+    ) -> Sequence[OpInfo]:
+        return ops  # returned unchanged
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a276144e53bd3145590775ecb13573bda3eb12f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py
@@ -0,0 +1,3998 @@
+# mypy: ignore-errors
+
+from abc import abstractmethod
+import tempfile
+import unittest
+
+from copy import deepcopy
+from functools import reduce, partial
+from itertools import product
+from operator import mul
+
+
+import torch
+import torch.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import _reduction as _Reduction
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
+    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
+from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
+from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
+from torch.autograd import Variable
+from torch.types import _TensorOrTensors
+import torch.backends.cudnn
+
+from typing import Union, Any
+from collections.abc import Callable
+from collections.abc import Sequence
+
+TemporaryFile = tempfile.TemporaryFile  # module-level alias for tempfile.TemporaryFile
+PRECISION = 1e-5  # NOTE(review): presumably the default comparison tolerance for tests using this module - confirm at use sites
+
+
+def get_reduction(m):  # Return m.reduction; when it is missing or None, derive it from the legacy `sizeAverage` attribute.
+    result = getattr(m, 'reduction', None)
+    if result is None:
+        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)  # legacy fallback, with the warning suppressed
+    assert result is not None  # a reduction must have been resolved by one of the two paths
+    return result
+
+
+def get_weight(m):  # Return m.weight when it is set (not None); otherwise m.weights, or None when that is absent too.
+    result = getattr(m, 'weight', None)
+    if result is not None:
+        return result
+    return getattr(m, 'weights', None)  # fall back to the plural attribute name
+
+# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
+#
+# The way to check API parity is to add parity tests for the NN module / functional of interest.
+# Here are the detailed steps:
+#
+# For NN module:
+# 1. Make sure you already have a test dict with the module configuration you want to test.
+# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
+#    the Python module constructor arguments. For example, if in the test dict we pass
+#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
+#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
+# 3. If in the process of performing the above step you referenced any variables
+#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
+#    to the test dict to make sure that those variables are populated with the right Python values.
+#    For example, if the Python constructor call is
+#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
+#    the corresponding C++ constructor argument is
+#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
+#    and the `cpp_var_map` entry must be
+#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
+#    used in the C++ constructor argument with the Python tensor value `random_samples`.
+#
+# For NN functional:
+# 1. Make sure you already have a test dict with the functional configuration you want to test.
+# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
+#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python
+#    functional optional arguments. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
+#    then the `cpp_options_args` entry should be
+#    "F::InterpolateFuncOptions().size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)".
+# 3. Otherwise, if the test dict's `constructor` entry looks like
+#    `wrap_functional(lambda i: F.some_functional_name(...))`,
+#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python
+#    functional function call. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
+#    then the `cpp_function_call` entry should be
+#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
+# 4. If in the process of performing the above two steps you referenced any variables
+#    in the `cpp_options_args` or `cpp_function_call` entry, you must
+#    add `cpp_var_map` entry to the test dict to make sure that those variables
+#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
+#    then the `cpp_function_call` entry should be
+#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
+#    Notice that there are two variables `i` and `t` that need to have their values provided,
+#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
+#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
+#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
+#
+# There are also a few optional flags in the test dict to control the C++ parity test behavior:
+#
+# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
+# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
+
+
+module_tests = [  # Declarative test configurations, one dict per case; the code consuming these dicts is not visible in this part of the file.
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8),
+        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
+        input_size=(4, 10),
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),  # presumably p = (weight, bias) - confirm against the caller
+        with_tf32=True,
+        tf32_precision=0.005,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8, False),
+        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
+        input_size=(4, 10),
+        desc='no_bias',
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
+        with_tf32=True,
+        tf32_precision=0.005,
+        # ROCM: skipping tf32 test on gfx94 archs due to tolerance issue.
+        test_cuda=not (TEST_WITH_ROCM and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName),  # device 0 is queried only when TEST_WITH_ROCM is truthy (short-circuit)
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='RReLU',
+        input_size=(1, 2, 2),
+        test_cuda=False,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='RReLU',
+        constructor_args=(0.1, 0.9),
+        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
+        input_size=(4, 4, 5),
+        desc='with_up_down',
+        test_cuda=False,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='Flatten',
+        input_size=(2, 3, 4, 5),
+        reference_fn=lambda i, *_: torch.flatten(i, 1),
+        default_dtype=torch.double,
+    ),
+    # TODO: reference function
+    dict(
+        module_name='CrossMapLRN2d',
+        constructor_args=(5, 5e-3, 1e-3, 2),
+        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
+        input_size=(2, 3, 6, 6),
+        check_gradgrad=False,
+        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
+        check_batched_grad=False,
+        default_dtype=torch.double,
+    ),
+]
+
+
+# Generates rand tensor with non-equal values. This ensures that duplicate
+# values won't be causing test failure for modules like MaxPooling.
+# size should be small, otherwise randperm fails / long overflows.
+def _rand_tensor_non_equal(*size):  # Return a double tensor of shape `size` whose entries are a random permutation of 0..total-1.
+    total = reduce(mul, size, 1)  # element count = product of all dimensions (1 when no dims are given)
+    return torch.randperm(total).view(*size).double()  # a permutation has no repeated values
+
+
+def wrap_functional(fn, **kwargs):  # Build an nn.Module subclass whose forward() calls fn(*args, **kwargs) with these bound kwargs.
+    class FunctionalModule(nn.Module):
+        def forward(self, *args):
+            return fn(*args, **kwargs)  # positional args come from the call, keyword args from the closure
+    return FunctionalModule  # the class itself is returned, not an instance
+
+
+def poissonnllloss_no_reduce_test():  # Build the test dict for F.poisson_nll_loss with reduction='none'.
+    t = torch.randn(10, 10)  # target, captured by the lambdas below and exposed to the C++ side through cpp_var_map
+    return dict(
+        fullname='PoissonNLLLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::poisson_nll_loss('
+                          'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: i.exp() - t.mul(i),  # elementwise exp(i) - t * i
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def bceloss_no_reduce_test():  # Build the test dict for F.binary_cross_entropy with reduction='none' on a (15, 10) input.
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))  # random 0/1 target stored as double
+    return dict(
+        fullname='BCELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),  # presumably clamped so the log terms in reference_fn stay finite - confirm
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),  # elementwise binary cross-entropy
+        pickle=False,
+        precision=7e-4,
+        default_dtype=torch.double)
+
+
+def bceloss_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    return dict(
+        fullname='BCELoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def bceloss_weights_no_reduce_test():
+    t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double))
+    weights = torch.rand(10, dtype=torch.double)
+    return dict(
+        fullname='BCELoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i),
+                                             weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), '
+                          'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        pickle=False,
+        precision=3e-4,
+        default_dtype=torch.double,
+    )
+
+
+def bceloss_weights_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    weights = torch.rand((), dtype=torch.double)
+    return dict(
+        fullname='BCELoss_weights_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i),
+                                             weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy(
+            i, t.to(i.options()),
+            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_legacy_enum_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_legacy_enum',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_no_reduce_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_with_target_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_with_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_no_reduce_scalar_test():
+    t = torch.rand((), dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(()).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_with_log_target_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_with_log_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_no_reduce_log_target_test():
+    t = torch.rand(10, 10, dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_no_reduce_log_target',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_no_reduce_scalar_log_target_test():
+    t = torch.rand((), dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_no_reduce_scalar_log_target',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(()).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def l1loss_no_reduce_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='L1Loss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def l1loss_no_reduce_complex_test():
+    t = torch.randn(2, 3, 4, dtype=torch.cdouble)
+    return dict(
+        fullname='L1Loss_no_reduce_complex',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False)
+
+
+def l1loss_no_reduce_scalar_test():
+    t = torch.randn((), dtype=torch.double)
+    return dict(
+        fullname='L1Loss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def mseloss_no_reduce_test():
+    input_size = (2, 3, 4, 5)
+    target = torch.randn(*input_size, dtype=torch.double)
+    return dict(
+        fullname='MSELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
+        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
+        input_size=input_size,
+        cpp_var_map={'i': '_get_input()', 'target': target},
+        reference_fn=lambda i, *_: (i - target).pow(2),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def mseloss_no_reduce_scalar_test():
+    input_size = ()
+    target = torch.randn(input_size, dtype=torch.double)
+    return dict(
+        fullname='MSELoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
+        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
+        input_size=input_size,
+        cpp_var_map={'i': '_get_input()', 'target': target},
+        reference_fn=lambda i, *_: (i - target).pow(2),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_ignore_index_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_ignore_index_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none',
+                'ignore_index': 2}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
+        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_ignore_index_neg_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none',
+                'ignore_index': -1}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
+        input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss2d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss2d_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLoss2d_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLossNd_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLossNd_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLossNd_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_no_reduce_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_no_reduce_scalar_test():
+    t = torch.randn((), dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_beta_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_beta',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_zero_beta_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_zero_beta',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def huberloss_delta_test():
+    t = torch.randn(2, 3, 4)
+    return dict(
+        fullname='HuberLoss_delta',
+        constructor=wrap_functional(
+            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
+        cpp_function_call='''F::huber_loss(
+            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_0d_no_reduce_test():
+    t = torch.zeros(()).long()
+    return dict(
+        fullname='MultiLabelMarginLoss_0d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False)
+
+
+def multilabelmarginloss_1d_no_reduce_test():
+    t = Variable(torch.rand(10).mul(10).floor().long())
+    return dict(
+        fullname='MultiLabelMarginLoss_1d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_index_neg_test():
+    t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
+    return dict(
+        fullname='MultiLabelMarginLoss_index_neg',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_no_reduce_test():
+    t = Variable(torch.rand(5, 10).mul(10).floor().long())
+    return dict(
+        fullname='MultiLabelMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def hingeembeddingloss_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::hinge_embedding_loss(
+            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
+        check_sum_reduction=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def hingeembeddingloss_margin_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
+        cpp_function_call='''F::hinge_embedding_loss(
+            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
+        check_sum_reduction=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def softmarginloss_no_reduce_test():
+    t = torch.randn(5, 5, dtype=torch.double)
+    return dict(
+        fullname='SoftMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::soft_margin_loss(
+            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 5),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelsoftmarginloss_no_reduce_test():
+    t = torch.rand(5, 10).mul(2).floor()
+    return dict(
+        fullname='MultiLabelSoftMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::multilabel_soft_margin_loss(
+            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelsoftmarginloss_weights_no_reduce_test():
+    t = torch.rand(5, 10).mul(2).floor()
+    weights = torch.rand(10)
+    return dict(
+        fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
+                                                    weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='''F::multilabel_soft_margin_loss(
+            i, t.to(i.options()),
+            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, *_:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_1d_no_reduce_test():
+    t = torch.rand(1).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_1d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_1d_input_0d_target_no_reduce_test():
+    t = torch.rand(()).mul(8).floor().long()
+    return dict(
+        fullname='multimarginloss_1d_input_0d_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_p_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_p_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_margin_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
+                                                  margin=0.5, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_weights_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    weights = torch.rand(10, dtype=torch.double)
+    return dict(
+        fullname='MultiMarginLoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
+                                          reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
+                                                  weight=weights, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def single_batch_reference_fn(input, parameters, module):
+    """Reference function for modules supporting no batch dimensions.
+
+    The module is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    single_batch_input = unsqueeze_inp(input)
+    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
+    with freeze_rng_state():
+        return module(*single_batch_input).squeeze(0)
+
+
+def get_new_module_tests():
+    common_utils.set_rng_seed()
+    new_module_tests = [
+        poissonnllloss_no_reduce_test(),
+        bceloss_no_reduce_test(),
+        bceloss_weights_no_reduce_test(),
+        bce_with_logistic_legacy_enum_test(),
+        bce_with_logistic_no_reduce_test(),
+        bceloss_no_reduce_scalar_test(),
+        bceloss_weights_no_reduce_scalar_test(),
+        bce_with_logistic_no_reduce_scalar_test(),
+        kldivloss_with_target_no_reduce_test(),
+        kldivloss_no_reduce_test(),
+        kldivloss_no_reduce_scalar_test(),
+        kldivloss_with_log_target_no_reduce_test(),
+        kldivloss_no_reduce_log_target_test(),
+        kldivloss_no_reduce_scalar_log_target_test(),
+        l1loss_no_reduce_test(),
+        l1loss_no_reduce_complex_test(),
+        l1loss_no_reduce_scalar_test(),
+        mseloss_no_reduce_test(),
+        mseloss_no_reduce_scalar_test(),
+        nllloss_no_reduce_test(),
+        nllloss_no_reduce_ignore_index_test(),
+        nllloss_no_reduce_weights_test(),
+        nllloss_no_reduce_weights_ignore_index_test(),
+        nllloss_no_reduce_weights_ignore_index_neg_test(),
+        nllloss2d_no_reduce_test(),
+        nllloss2d_no_reduce_weights_test(),
+        nllloss2d_no_reduce_ignore_index_test(),
+        nlllossNd_no_reduce_test(),
+        nlllossNd_no_reduce_weights_test(),
+        nlllossNd_no_reduce_ignore_index_test(),
+        smoothl1loss_no_reduce_test(),
+        smoothl1loss_no_reduce_scalar_test(),
+        smoothl1loss_beta_test(),
+        smoothl1loss_zero_beta_test(),
+        huberloss_delta_test(),
+        multilabelmarginloss_0d_no_reduce_test(),
+        multilabelmarginloss_1d_no_reduce_test(),
+        multilabelmarginloss_index_neg_test(),
+        multilabelmarginloss_no_reduce_test(),
+        hingeembeddingloss_no_reduce_test(),
+        hingeembeddingloss_margin_no_reduce_test(),
+        softmarginloss_no_reduce_test(),
+        multilabelsoftmarginloss_no_reduce_test(),
+        multilabelsoftmarginloss_weights_no_reduce_test(),
+        multimarginloss_no_reduce_test(),
+        multimarginloss_1d_no_reduce_test(),
+        multimarginloss_1d_input_0d_target_no_reduce_test(),
+        multimarginloss_p_no_reduce_test(),
+        multimarginloss_margin_no_reduce_test(),
+        multimarginloss_weights_no_reduce_test(),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='stride',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3, 1, 1),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='pad1',
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 5, 1, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='pad2',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 4, 3, 1, 1),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
+            input_size=(1, 4, 1),
+            cudnn=True,
+            desc='pad1size1',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 4, 5, 1, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
+            input_size=(1, 4, 1),
+            cudnn=True,
+            desc='pad2size1',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
+            input_size=(0, 4, 10),
+            cudnn=True,
+            desc='zero_batch',
+            with_tf32=True,
+            tf32_precision=0.005,
+        ),
+        dict(
+            fullname='Conv1d_dilated',
+            constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
+            input_size=(2, 4, 10),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_groups',
+            constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
+            input_size=(2, 4, 6),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_valid',
+            constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same',
+            constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same2',
+            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same_dilated',
+            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose1d',
+            constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
+            cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
+            cudnn=True,
+            input_size=(1, 3, 7),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose1d',
+            constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
+                                    .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
+            input_size=(1, 3, 6),
+            cudnn=True,
+            desc='no_bias',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose1d',
+            constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
+                                    .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
+            input_size=(1, 3, 6),
+            cudnn=True,
+            desc='dilated',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose1d_groups',
+            constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
+                                    .stride(3).padding(1).output_padding(1).groups(2)''',
+            cudnn=True,
+            input_size=(2, 4, 7),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
+            input_size=(2, 3, 7, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 3), (2, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
+            input_size=(2, 3, 6, 6),
+            cudnn=True,
+            desc='strided',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
+            input_size=(2, 3, 6, 6),
+            cudnn=True,
+            desc='padding',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
+            input_size=(2, 3, 8, 8),
+            cudnn=True,
+            desc='dilated',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(2, 3, 6, 5),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
+            input_size=(0, 3, 7, 5),
+            cudnn=True,
+            desc='zero_batch',
+            check_with_long_tensor=True,
+            with_tf32=True,
+        ),
+        dict(
+            fullname='Conv2d_groups',
+            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
+            input_size=(2, 4, 6, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_groups_thnn',
+            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
+            input_size=(2, 4, 6, 5),
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_valid',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_same',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_same_dilated',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({3, 2}).padding(1).output_padding({1, 1})''',
+            cudnn=True,
+            input_size=(1, 3, 7, 6),
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({2, 3})
+                                    .padding(1)
+                                    .output_padding({1, 1})
+                                    .groups(1)
+                                    .bias(false)
+                                    .dilation({2, 2})''',
+            input_size=(1, 3, 6, 7),
+            cudnn=True,
+            desc='dilated',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
+            input_size=(1, 3, 6, 7),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose2d_groups',
+            constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
+            cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
+            input_size=(1, 2, 4, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_with_multiplier',
+            constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_strided',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_padded',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_dilated',
+            constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
+            input_size=(2, 4, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (2, 3, 2)),
+            cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
+            input_size=(1, 2, 4, 5, 4),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            cudnn=True,
+            desc='1x1x1_no_bias',
+            check_with_long_tensor=False,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, 2, 2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
+            input_size=(2, 3, 5, 5, 5),
+            cudnn=True,
+            desc='stride',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, 2, 2, 1),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
+            input_size=(2, 3, 5, 5, 5),
+            cudnn=True,
+            desc='stride_padding',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, (2, 3, 4)),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
+            input_size=(0, 3, 3, 4, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            desc='zero_batch',
+            with_tf32=True,
+        ),
+        dict(
+            fullname='Conv3d_groups',
+            constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
+            input_size=(1, 2, 4, 5, 4),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_dilated',
+            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
+            input_size=(2, 3, 5, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_dilated_strided',
+            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
+            input_size=(2, 3, 5, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_valid',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_same',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_same_dilated',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose3d',
+            constructor_args=(2, 3, (2, 3, 2)),
+            cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
+            cudnn=True,
+            input_size=(1, 2, 4, 5, 4),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose3d',
+            constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
+            cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
+                                    .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
+            cudnn=True,
+            input_size=(1, 2, 4, 5, 4),
+            desc='dilated',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_size=(2, 3, 2, 2, 2),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_size=(3, 2, 2, 2),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
+            skip_half=True,
+            desc='complex'
+        ),
+        dict(
+            module_name='Embedding',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
+        ),
+        dict(
+            module_name='Embedding',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+            check_gradgrad=False,
+            desc='discontiguous',
+            default_dtype=torch.double,
+            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='mean',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+            check_gradgrad=False,
+            desc='discontiguous',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3, None, 2., False, 'sum'),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='sum',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3, None, 2., False, 'max'),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='max',
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_mean_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_sum_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_max_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_sparse',
+            constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''',
+            input_fn=lambda: torch.randperm(2).repeat(1, 2),
+            check_gradgrad=False,
+            has_sparse_gradients=True,
+        ),
+        dict(
+            constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
+            input_fn=lambda: torch.randperm(2).repeat(1, 2),
+            fullname='Embedding_sparse',
+            check_gradgrad=False,
+            has_sparse_gradients=True,
+        ),
+        dict(
+            module_name='PixelShuffle',
+            constructor_args=(3,),
+            cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
+            input_size=(1, 9, 4, 4),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PixelUnshuffle',
+            constructor_args=(3,),
+            cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
+            input_size=(1, 1, 12, 12),
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_nearest_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(0, 2, 4),
+            fullname='interpolate_nearest_1d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(1, 2, 3),
+            fullname='interpolate_nearest_tuple_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt).scale_factor(std::vector({4.})).mode(torch::kNearest)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_nearest_scale_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 3),
+            fullname='interpolate_linear_tuple_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4.}))
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_scale_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4),
+            fullname='interpolate_linear_1d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_1d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4.}))
+                                .mode(torch::kLinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_scale_1d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({2, 2}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 128, 1, 1),
+            fullname='interpolate_nearest_2d_launch_configs',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_nearest_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 16}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 3, 4),
+            fullname='interpolate_nearest_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4.}))
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_nearest_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_nearest_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_bilinear_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3),
+            fullname='interpolate_bilinear_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 2.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_shared_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_skewed_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_tuple_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_bicubic_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3),
+            fullname='interpolate_bicubic_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 2.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_shared_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_skewed_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_tuple_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bicubic', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_nearest_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(0, 2, 4, 4, 4),
+            fullname='interpolate_nearest_3d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 16, 16}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 3, 4, 4),
+            fullname='interpolate_nearest_tuple_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4., 4.}))
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_nearest_scale_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_trilinear_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4, 4),
+            fullname='interpolate_trilinear_3d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
+                                        scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3, 3),
+            fullname='interpolate_trilinear_tuple_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({3., 3., 3.}))
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            fullname='interpolate_trilinear_scale_3d',
+            # See https://github.com/pytorch/pytorch/issues/5006
+            precision=3e-4,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
+                                        mode='trilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 2, 3, 3),
+            fullname='interpolate_trilinear_tuple_3d_align_corners',
+            pickle=False,
+            default_dtype=torch.double
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({3., 3., 3.}))
+                                .mode(torch::kTrilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 3, 4, 4),
+            fullname='interpolate_trilinear_scale_3d_align_corners',
+            # See https://github.com/pytorch/pytorch/issues/5006
+            precision=3e-4,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=-1),
+            cpp_options_args='F::SoftmaxFuncOptions(-1)',
+            input_size=(2, 128),  # trigger the last-dim algo in CUDA
+            fullname='softmax_lastdim',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
+            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
+            input_size=(2, 128),
+            fullname='softmax_lastdim_dtype',
+            pickle=False,
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1),
+            cpp_options_args='F::SoftmaxFuncOptions(1)',
+            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+            fullname='softmax_spatial_special',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1),
+            cpp_options_args='F::SoftmaxFuncOptions(1)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='softmax_spatial',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
+            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='softmax_spatial_dtype',
+            pickle=False,
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=0),
+            cpp_options_args='F::SoftmaxFuncOptions(0)',
+            input_size=(2, 3, 4, 5),
+            fullname='softmax_functional_dim0',
+            test_cuda=False,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=3),
+            cpp_options_args='F::SoftmaxFuncOptions(3)',
+            input_size=(2, 3, 4, 5),
+            fullname='softmax_functional_dim3',
+            test_cuda=False,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=-1),
+            cpp_options_args='F::SoftmaxFuncOptions(-1)',
+            input_size=(),
+            fullname='softmax_functional_scalar',
+            test_cuda=False,
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=-1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
+            input_size=(2, 128),  # trigger the last-dim algo in CUDA
+            fullname='log_softmax_lastdim',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
+            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+            fullname='log_softmax_spatial_special',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='log_softmax_spatial',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=0),
+            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
+            input_size=(2, 3, 4, 5),
+            fullname='log_softmax_dim0',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=3),
+            cpp_options_args='F::LogSoftmaxFuncOptions(3)',
+            input_size=(2, 3, 4, 5),
+            fullname='log_softmax_dim3',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=0),
+            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
+            input_size=(),
+            fullname='log_softmax_scalar',
+            pickle=False,
+        ),
+        dict(
+            fullname='Unfold',
+            constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(2, 4, 3, 3),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold',
+            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(2, 16, 4),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_no_batch_dim_input',
+            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(16, 4),
+            check_gradgrad=False,
+            ref=single_batch_reference_fn,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Unfold_int_input',
+            constructor=lambda: nn.Unfold(2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
+            input_size=(2, 4, 3, 3),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_int_input',
+            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
+            input_size=(2, 16, 4),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_no_batch_dim_int_input',
+            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
+            input_size=(16, 4),
+            ref=single_batch_reference_fn,
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='RReLU',
+            constructor_args=(0.1, 0.9),
+            cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
+            input_size=(),
+            desc='with_up_down_scalar',
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
+            desc='broadcast_lhs',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
+            desc='broadcast_rhs',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            constructor_args=(1.5, 1e-05, True),
+            cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+            desc='with_non_default_args',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(8), torch.randn(8)),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerEncoderLayer',
+            constructor_args=(4, 2, 16, 0.0),
+            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
+                                    .dim_feedforward(16)
+                                    .dropout(0.0)''',
+            input_size=(2, 3, 4),
+            desc='relu_activation',
+            with_tf32=True,
+            tf32_precision=0.1,
+            # TODO(#50743): figure out the error
+            # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
+            # at non-singleton dimension 2
+            check_batched_grad=False,
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerEncoderLayer',
+            constructor_args=(4, 2, 8, 0.0, F.gelu),
+            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kGELU)''',
+            input_size=(2, 3, 4),
+            check_gradgrad=False,
+            desc='gelu_activation',
+            with_tf32=True,
+            tf32_precision=0.08 if SM90OrLater else 0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerDecoderLayer',
+            constructor_args=(4, 2, 8, 0.0),
+            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
+            check_gradgrad=False,
+            desc='relu_activation',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerDecoderLayer',
+            constructor_args=(4, 2, 8, 0.0, F.gelu),
+            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kGELU)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
+            check_gradgrad=False,
+            desc='gelu_activation',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Transformer',
+            constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
+            cpp_constructor_args='''torch::nn::TransformerOptions()
+                                    .d_model(4)
+                                    .nhead(2)
+                                    .num_encoder_layers(2)
+                                    .num_decoder_layers(2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kReLU)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
+            check_gradgrad=False,
+            desc='multilayer_coder',
+            with_tf32=True,
+            tf32_precision=0.05 if SM90OrLater else 0.03,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Linear',
+            constructor_args=(3, 5),
+            cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
+            input_fn=lambda: torch.rand(3),
+            reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
+            desc="no_batch_dim",
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Flatten',
+            cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
+            constructor_args=(-3, -1),
+            input_size=(3, 4, 5),
+            reference_fn=single_batch_reference_fn,
+            desc="no_batch_dim",
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Unflatten',
+            cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
+            constructor_args=(-2, torch.Size([2, 2])),
+            input_size=(3, 4, 5),
+            reference_fn=single_batch_reference_fn,
+            desc="no_batch_dim",
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='LayerNorm',
+            constructor_args=([56, 56, 56], 1e-5, False),
+            cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
+            input_size=(4, 56, 56, 56),
+            cudnn=True,
+            check_eval=True,
+            gradcheck_fast_mode=True,
+            check_half=True,
+            desc='3d_no_affine_large_feature',
+        ),
+    ]
+
+    # add conv padding mode tests:
+    for padding_mode, cpp_padding_mode in zip(
+            ['reflect', 'circular', 'replicate', 'zeros'],
+            ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros'], strict=True):
+        # conv signature:
+        #     in_channels, out_channels, kernel_size, stride=1,
+        #     padding=0, dilation=1, groups=1,
+        #     bias=True, padding_mode='zeros'
+        for d in (1, 2, 3):
+            if d == 3 and padding_mode == 'reflect':
+                # FIXME: remove after implementing reflection pad 3d
+                #        https://github.com/pytorch/pytorch/issues/27655
+                continue
+            padding = tuple(range(1, d + 1))
+            cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
+            input_size = (2, 2) + (4,) * d
+            output_size = (2, 3) + tuple(p + 1 for p in padding)  # simplified from `(4 + 2 * p - 3) // 2 + 1`
+            new_module_tests.append(
+                dict(
+                    module_name=f'Conv{d}d',
+                    constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
+                    cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3)
+                                            .stride(2)
+                                            .padding({cpp_padding})
+                                            .dilation(1)
+                                            .groups(1)
+                                            .bias(true)
+                                            .padding_mode({cpp_padding_mode})''',
+                    input_size=input_size,
+                    output_size=output_size,
+                    cudnn=True,
+                    desc=f'{padding_mode}_stride2_pad2',
+                    with_tf32=True,
+                    tf32_precision=0.05,
+                    default_dtype=torch.double,
+                ),
+            )
+
+    # Check that non linear activations work with no batch dimensions
+    non_linear_activations_no_batch = [
+        'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
+        'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
+        'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
+        'Tanhshrink', 'Threshold'
+    ]
+    non_linear_activations_extra_info: dict[str, dict] = {
+        'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
+        'Threshold': {'constructor_args': (2., 1.)},
+        'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
+        'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
+        # For RReLU, tests that compare CPU and GPU results fail because the
+        # RNG differs between CPU and GPU
+        'RReLU': {'test_cuda': False, 'default_dtype': torch.double},
+        'ELU': {'default_dtype': torch.double},
+        'GELU': {'default_dtype': torch.double},
+        'GLU': {'default_dtype': torch.double},
+        'Hardshrink': {'default_dtype': torch.double},
+        'Hardtanh': {'default_dtype': torch.double},
+        'LeakyReLU': {'default_dtype': torch.double},
+        'LogSigmoid': {'default_dtype': torch.double},
+        'Mish': {'default_dtype': torch.double},
+        'PReLU': {'default_dtype': torch.double},
+        'ReLU6': {'default_dtype': torch.double},
+        'ReLU': {'default_dtype': torch.double},
+        'SELU': {'default_dtype': torch.double},
+        'SiLU': {'default_dtype': torch.double},
+        'Sigmoid': {'default_dtype': torch.double},
+        'Softplus': {'default_dtype': torch.double},
+        'Softshrink': {'default_dtype': torch.double},
+        'Softsign': {'default_dtype': torch.double},
+        'Tanh': {'default_dtype': torch.double},
+        'Tanhshrink': {'default_dtype': torch.double},
+    }
+    for non_linear_activation in non_linear_activations_no_batch:
+        activation_test_info = dict(
+            module_name=non_linear_activation,
+            input_size=(4,),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            test_cpp_api_parity=False,
+        )
+        extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
+        activation_test_info.update(extra_info)
+        new_module_tests.append(activation_test_info)
+
+
+    return new_module_tests
+
+
+def kldivloss_reference(input, target, reduction='mean', log_target=False):
+    """Reference KL-divergence loss built from elementwise tensor ops.
+
+    The pointwise term is ``target * (target.log() - input)``, or
+    ``exp(target) * (target - input)`` when ``log_target`` is true.
+
+    ``reduction`` selects the result: 'mean' and 'sum' reduce over every
+    element, 'batchmean' divides the sum by ``size(0)`` (only when the
+    pointwise result is not 0-dim), and any other value returns the pointwise
+    tensor unreduced.
+    """
+    target_prob = torch.exp(target) if log_target else target
+    target_log_prob = target if log_target else target.log()
+    pointwise = target_prob * (target_log_prob - input)
+
+    if reduction == 'batchmean' and pointwise.dim() != 0:
+        return pointwise.sum() / pointwise.size(0)
+    if reduction == 'sum':
+        return pointwise.sum()
+    if reduction == 'mean':
+        return pointwise.mean()
+    return pointwise
+
+
+def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
+                        reduction='mean'):
+    """Reference N-d negative log-likelihood loss, computed one element at a time.
+
+    Args:
+        input: tensor with at least 3 dims; dim 1 is indexed by the class ids
+            stored in ``target``.
+        target: integer tensor of class ids, indexed with every index of
+            ``(N,) + input.size()[2:]``.
+        weight: optional per-class weights; defaults to all ones.
+        ignore_index: target value whose elements contribute zero loss and
+            zero weight.
+        reduction: 'mean' divides the summed loss by the summed weights of the
+            non-ignored elements, 'sum' adds the losses, and any other value
+            returns the unreduced tensor.
+    """
+    assert input.dim() >= 3
+    N = input.size(0)
+    C = input.size(1)
+    out_size = (N,) + input.size()[2:]
+    output = torch.zeros(out_size).type_as(input)
+
+    if weight is None:
+        weight = torch.ones(C).type_as(input)
+    total_weight = 0
+    for tup in product(*[range(size) for size in out_size]):
+        t_nx = target[tup]
+        if ignore_index == t_nx:
+            # Ignored element: keep the zero already stored in ``output``.
+            # Indexing ``input`` with the ignored id would raise when it is
+            # not a valid class index (e.g. the default -100).
+            continue
+        norm = weight[t_nx].item()
+        input_index = list(tup)
+        input_index.insert(1, t_nx)
+        output[tup] = -input[tuple(input_index)] * norm
+        total_weight += norm
+
+    if reduction == 'mean':
+        return output.sum() / total_weight
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
+                                             label_smoothing=0.0):
+    """Reference cross entropy for targets that have the same shape as ``input``.
+
+    ``input`` is passed through ``log_softmax`` over dim 1, multiplied
+    elementwise by ``target`` and the per-class ``weight`` (all ones when
+    omitted), and summed over dim 1 with a negative sign. A positive
+    ``label_smoothing`` first mixes ``target`` with ``label_smoothing / C``.
+
+    ``reduction`` 'mean' or 'sum' reduces the per-sample losses; any other
+    value returns them unreduced.
+    """
+    assert input.dim() >= 2
+
+    log_probs = torch.log_softmax(input, 1)
+    num_classes = log_probs.size(1)
+    class_weight = torch.ones(num_classes).type_as(log_probs) if weight is None else weight
+    # Shape (1, C, 1, ..., 1) so the weights broadcast along the class dim.
+    class_weight = class_weight.view(1, num_classes, *([1] * len(log_probs.shape[2:])))
+
+    if label_smoothing > 0.0:
+        assert label_smoothing <= 1.0
+        target = target * (1 - label_smoothing) + label_smoothing / num_classes
+
+    per_sample = -(log_probs * target * class_weight).sum(dim=1)
+    if reduction == 'sum':
+        return per_sample.sum()
+    if reduction == 'mean':
+        return per_sample.mean()
+    return per_sample
+
+
+def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
+                                                reduction='mean', label_smoothing=0.0):
+    """Reference cross entropy for class-index targets, with optional label smoothing.
+
+    The base term is ``F.nll_loss`` applied to ``log_softmax(input, 1)``. When
+    ``label_smoothing`` is non-zero, a second term is added: the negated
+    (optionally class-weighted) sum of the log-probabilities over dim 1,
+    zeroed where ``target == ignore_index`` and reduced according to
+    ``reduction``. The two terms are blended as
+    ``(1 - label_smoothing) * nll + smooth * (label_smoothing / C)``.
+    """
+    log_probs = torch.log_softmax(input, 1)
+    nll_term = F.nll_loss(
+        log_probs, target, weight, ignore_index=ignore_index, reduction=reduction)
+
+    if label_smoothing == 0.0:
+        return nll_term
+
+    assert 0.0 < label_smoothing <= 1.0
+
+    num_classes = log_probs.size(1)
+    weighted_log_probs = log_probs
+    if weight is not None:
+        # Shape (1, C, 1, ..., 1) so the weights broadcast along the class dim.
+        weighted_log_probs = log_probs * weight.view(1, num_classes, *(1 for _ in log_probs.shape[2:]))
+
+    ignore_mask = target == ignore_index
+    keep_mask = ignore_mask.logical_not()
+    smooth_loss = -torch.sum(weighted_log_probs, 1)
+    smooth_loss.masked_fill_(ignore_mask, 0.0)
+
+    if reduction == 'sum':
+        smooth_term = torch.sum(smooth_loss)
+    elif reduction != 'mean':
+        smooth_term = smooth_loss
+    elif weight is None:
+        smooth_term = torch.mean(smooth_loss.masked_select(keep_mask))
+    else:
+        # TODO: This code path can be removed if #61309 is resolved
+        # loss is normalized by the weights to be consistent with nll_loss_nd
+        kept_weight_sum = weight.gather(0, target.masked_select(keep_mask).flatten()).sum()
+        smooth_term = torch.sum(smooth_loss) / kept_weight_sum
+
+    return (1 - label_smoothing) * nll_term + smooth_term * (label_smoothing / num_classes)
+
+
+def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
+                                 label_smoothing=0.0):
+    """Dispatch to the matching cross-entropy reference implementation.
+
+    A ``target`` with the same shape as ``input`` goes to the probability-target
+    reference (which takes no ``ignore_index``); any other shape goes to the
+    class-index reference.
+    """
+    if input.shape != target.shape:
+        return cross_entropy_loss_indices_target_reference(
+            input, target, weight=weight, reduction=reduction,
+            ignore_index=ignore_index, label_smoothing=label_smoothing
+        )
+    return cross_entropy_loss_prob_target_reference(
+        input,
+        target,
+        weight=weight,
+        reduction=reduction,
+        label_smoothing=label_smoothing)
+
+
+def nllloss_reference(input, target, weight=None, ignore_index=-100,
+                      reduction='mean'):
+    """Reference negative log-likelihood loss, evaluated sample by sample.
+
+    ``input`` and ``target`` are iterated in lockstep along dim 0. Each sample
+    contributes ``-input_row[label] * w`` with ``w = weight[label]`` (1 when
+    ``weight`` is None); samples whose label equals ``ignore_index``
+    contribute zero loss and zero weight.
+
+    ``reduction`` 'mean' divides the summed losses by the summed weights,
+    'sum' adds the losses, and any other value returns the per-sample tensor.
+    """
+
+    def sample_loss_and_weight(row, label):
+        # Returns (loss, weight) for a single sample.
+        if label == ignore_index:
+            return (0, 0)
+        scale = weight[label] if weight is not None else 1
+        return (-row[label] * scale, scale)
+
+    pairs = [sample_loss_and_weight(row, label)
+             for row, label in zip(input, target, strict=True)]
+    losses, weights = zip(*pairs, strict=True)
+    losses_tensor = input.new_tensor(losses)
+    if reduction == 'sum':
+        return sum(losses_tensor)
+    if reduction == 'mean':
+        return sum(losses_tensor) / sum(weights)
+    return losses_tensor
+
+
+def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
+    """Reference smooth-L1 loss.
+
+    With ``d = |input - target|`` the elementwise loss is ``d - 0.5 * beta``
+    where ``d >= beta`` and ``0.5 * d**2 / beta`` where ``d < beta``. For
+    ``beta == 0`` the loss is ``d`` itself.
+
+    ``reduction`` 'mean' or 'sum' reduces over all elements; any other value
+    returns the elementwise tensor.
+    """
+    abs_diff = (input - target).abs()
+    if beta == 0:
+        # beta == 0 reduces to the plain L1 loss (the quadratic branch would
+        # divide by zero)
+        output = abs_diff
+    else:
+        linear_mask = (abs_diff >= beta).type_as(abs_diff)
+        quadratic_mask = (abs_diff < beta).type_as(abs_diff)
+        linear_part = linear_mask * (abs_diff - 0.5 * beta)
+        quadratic_part = quadratic_mask * 0.5 * (abs_diff ** 2) / beta
+        output = linear_part + quadratic_part
+
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+def huberloss_reference(input, target, reduction='mean', delta=1.0):
+    """Reference Huber loss.
+
+    With ``d = |input - target|`` the elementwise loss is
+    ``delta * (d - 0.5 * delta)`` where ``d >= delta`` and ``0.5 * d**2``
+    where ``d < delta``.
+
+    ``reduction`` 'mean' or 'sum' reduces over all elements; any other value
+    returns the elementwise tensor.
+    """
+    abs_diff = (input - target).abs()
+    # The boolean masks act as 0/1 selectors for the two branches.
+    linear_part = (abs_diff >= delta) * delta * (abs_diff - 0.5 * delta)
+    quadratic_part = (abs_diff < delta) * 0.5 * (abs_diff ** 2)
+    output = linear_part + quadratic_part
+
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+def _multilabelmarginloss_reference(input, target):
+    """Unnormalized multi-label margin loss for a single sample.
+
+    The target classes are the leading entries of ``target`` up to (not
+    including) the first negative one. For every target class ``t`` and every
+    index ``i`` that is not a target class, ``max(0, 1 - input[t] + input[i])``
+    is accumulated. Returns the accumulated total (the int 0 when nothing is
+    accumulated).
+    """
+    target_classes = []
+    for label in target:
+        if label < 0:
+            # A negative entry terminates the list of target classes.
+            break
+        target_classes.append(label)
+
+    total = 0
+    for label in target_classes:
+        for i, score in enumerate(input):
+            if i in target_classes:
+                continue
+            total += max(0, 1 - input[label] + score)
+
+    return total
+
+
+def multilabelmarginloss_reference(input, target, reduction='mean'):
+    """Reference multi-label margin loss.
+
+    Inputs with fewer than two dims are promoted to 2-d, the per-sample helper
+    is evaluated for each row, and the result is divided by ``input.size(1)``.
+    ``reduction`` 'mean' or 'sum' reduces over the rows first; any other value
+    returns the per-row losses, squeezed when the original input had fewer
+    than two dims.
+    """
+    input_dim = input.dim()
+    if input_dim < 2:
+        # make everything 2-dimensional
+        assert target.dim() < 2
+        if input_dim == 1:
+            input = input.unsqueeze(0)
+        else:
+            input = input.unsqueeze(0).unsqueeze(0)
+        if target.dim() == 1:
+            target = target.unsqueeze(0)
+        else:
+            target = target.unsqueeze(0).unsqueeze(0)
+
+    n = input.size(0)
+    dim = input.size(1)
+    output = input.new(n).zero_()
+    for row in range(n):
+        output[row] = _multilabelmarginloss_reference(input[row], target[row])
+
+    if reduction == 'mean':
+        return output.mean() / dim
+    if reduction == 'sum':
+        return output.sum() / dim
+    if input_dim < 2:
+        # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us
+        # back to correct dimensionality
+        return output.squeeze() / dim
+    return output / dim
+
+
+def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
+    """Reference hinge embedding loss.
+
+    Elementwise, the loss is ``input`` where ``target == 1`` and
+    ``max(0, margin - input)`` everywhere else. ``reduction`` 'mean' or 'sum'
+    reduces over all elements; any other value returns the elementwise tensor.
+    """
+    hinge = (margin - input).clamp(min=0).type_as(input)
+    output = torch.where(target == 1, input, hinge)
+
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+def softmarginloss_reference(input, target, reduction='mean'):
+    """Reference soft margin loss: ``log(1 + exp(-input * target))`` elementwise.
+
+    ``reduction`` 'mean' or 'sum' reduces over all elements; any other value
+    returns the elementwise tensor.
+    """
+    exponent = -input * target
+    output = (1 + exponent.exp()).log()
+
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+def _multimarginloss_reference(input, target_idx, p, margin, weight):
+    """Unnormalized multi-class margin loss for a single sample.
+
+    For every index ``i`` other than ``target_idx`` this accumulates
+    ``weight[target_idx] * max(0, margin - input[target_idx] + input[i]) ** p``.
+    ``weight`` defaults to all ones when None. Returns the accumulated total
+    (the int 0 when nothing is accumulated).
+    """
+    if weight is None:
+        weight = input.new(len(input)).fill_(1)
+
+    total = 0
+    for i, score in enumerate(input):
+        if i == target_idx:
+            continue
+        violation = max(0, (margin - input[target_idx] + score))
+        total += weight[target_idx] * (violation ** p)
+    return total
+
+
+def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
+    """Reference multi-class margin loss.
+
+    ``input`` with fewer than two dims is promoted to 2-d and a 0-dim
+    ``target`` to 1-d. The per-sample helper is evaluated for each row and the
+    result is divided by ``input.size(1)``. ``reduction`` 'mean' or 'sum'
+    reduces over the rows first; any other value returns the per-row losses,
+    with dim 0 squeezed when the original target was 0-dim.
+    """
+    if input.dim() < 2:
+        if input.dim() == 1:
+            input = input.unsqueeze(0)
+        else:
+            input = input.unsqueeze(0).unsqueeze(0)
+
+    target_dim = target.dim()
+    if target_dim == 0:
+        target = target.unsqueeze(0)
+
+    n = input.size(0)
+    dim = input.size(1)
+    # Uninitialized buffer; every entry is overwritten in the loop below.
+    output = input.new(n)
+    for row in range(n):
+        output[row] = _multimarginloss_reference(input[row], target[row], p, margin, weight)
+
+    if reduction == 'mean':
+        return output.mean() / dim
+    if reduction == 'sum':
+        return output.sum() / dim
+    if target_dim == 0:
+        return output.squeeze(0) / dim
+    return output / dim
+
+
+def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
+    """Reference cosine embedding loss.
+
+    A cosine similarity is computed for each index along dim 0 of the inputs,
+    with 1e-12 added to each squared norm. The loss is ``1 - cos`` where
+    ``target == 1`` and ``max(0, cos - margin)`` elsewhere. ``reduction``
+    'mean' or 'sum' reduces the result; any other value returns it unreduced.
+    """
+    num_rows = input1.size(0)
+    cosine = input1.new(num_rows)
+    for row in range(num_rows):
+        a, b = input1[row], input2[row]
+        dot = (a * b).sum()
+        # Product of the (epsilon-padded) squared norms.
+        norm_sq_product = ((a * a).sum() + 1e-12) * ((b * b).sum() + 1e-12)
+        cosine[row] = dot / (norm_sq_product ** 0.5)
+
+    similar_loss = 1 - cosine
+    dissimilar_loss = (cosine - margin).clamp(min=0)
+    output = torch.where(target == 1, similar_loss, dissimilar_loss)
+
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
+                                reduction='mean'):
+    """Reference triplet margin loss built on ``torch.pairwise_distance``.
+
+    The loss is ``max(0, margin + d(anchor, positive) - d(anchor, negative))``.
+    With ``swap`` the negative distance is replaced by the smaller of
+    ``d(anchor, negative)`` and ``d(positive, negative)``. ``reduction``
+    'mean' or 'sum' reduces the result; any other value returns it unreduced.
+    """
+    dist_pos = torch.pairwise_distance(anchor, positive, p, eps)
+    dist_neg = torch.pairwise_distance(anchor, negative, p, eps)
+    if swap:
+        dist_swap = torch.pairwise_distance(positive, negative, p, eps)
+        dist_neg = torch.min(dist_neg, dist_swap)
+
+    output = torch.clamp(margin + dist_pos - dist_neg, min=0.0)
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
+    """Reference margin ranking loss: ``max(0, -target * (input1 - input2) + margin)``.
+
+    ``reduction`` 'mean' or 'sum' reduces the result; any other value returns
+    the elementwise tensor.
+    """
+    signed_gap = -target * (input1 - input2)
+    output = (signed_gap + margin).clamp(min=0)
+    if reduction == 'sum':
+        return output.sum()
+    if reduction == 'mean':
+        return output.mean()
+    return output
+
+
+# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space
+def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
+    """Reference CTC loss using the forward (alpha) recursion in probability space.
+
+    Args:
+        log_probs: log-probabilities indexed as ``[time, batch, class]``.
+        targets: either 2-d (one row per batch item) or 1-d with all target
+            sequences concatenated in batch order.
+        input_lengths: per-item number of time steps to use (anything
+            ``torch.as_tensor`` accepts).
+        target_lengths: per-item target length (anything ``torch.as_tensor``
+            accepts). Zero-length targets are supported.
+        blank: class index of the blank label.
+        reduction: 'mean' divides each loss by its target length (clamped to
+            at least 1) and averages over the batch, 'sum' adds the losses,
+            and any other value returns the per-item losses.
+
+    Returns:
+        The loss, cast back to the dtype ``log_probs`` had on entry.
+    """
+    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
+    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
+    dt = log_probs.dtype
+    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
+    targets = targets.long()
+    cum_target_lengths = target_lengths.cumsum(0)
+    losses = []
+    for i in range(log_probs.size(1)):
+        input_length = input_lengths[i].item()
+        target_length = target_lengths[i].item()
+        cum_target_length = cum_target_lengths[i].item()
+        # Extended target: the labels at odd positions, blanks everywhere else.
+        targets_prime = targets.new_full((2 * target_length + 1,), blank)
+        if targets.dim() == 2:
+            targets_prime[1::2] = targets[i, :target_length]
+        else:
+            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
+        probs = log_probs[:input_length, i].exp()
+        alpha = log_probs.new_zeros((target_length * 2 + 1,))
+        alpha[0] = probs[0, blank]
+        if target_length > 0:
+            # An empty target has only the single blank state, so there is no
+            # first-label state to initialize.
+            alpha[1] = probs[0, targets_prime[1]]
+        # The transition from two states back is only allowed when the two
+        # extended-target symbols differ.
+        mask_third = (targets_prime[:-2] != targets_prime[2:])
+        for t in range(1, input_length):
+            alpha_next = alpha.clone()
+            alpha_next[1:] += alpha[:-1]
+            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
+            alpha = probs[t, targets_prime] * alpha_next
+        # The total probability is read from the last two states (a single
+        # state when the target is empty).
+        losses.append(-alpha[-2:].sum().log()[None])
+    output = torch.cat(losses, 0)
+    if reduction == 'mean':
+        # Clamp so a zero-length target does not divide by zero.
+        clamped_lengths = target_lengths.clamp_min(1)
+        output = (output / clamped_lengths.to(dtype=output.dtype, device=output.device)).mean()
+    elif reduction == 'sum':
+        output = output.sum()
+    output = output.to(dt)
+    return output
+
+
+# Maps a loss name to the reference implementation defined above for it.
+# The annotation uses the type ``str``; the quoted ``'str'`` was a string
+# literal standing in for a type.
+loss_reference_fns: dict[str, Callable] = {
+    'KLDivLoss': kldivloss_reference,
+    'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
+    'NLLLoss': nllloss_reference,
+    'NLLLossNd': nlllossNd_reference,
+    'SmoothL1Loss': smoothl1loss_reference,
+    'HuberLoss': huberloss_reference,
+    'MultiLabelMarginLoss': multilabelmarginloss_reference,
+    'HingeEmbeddingLoss': hingeembeddingloss_reference,
+    'SoftMarginLoss': softmarginloss_reference,
+    'MultiMarginLoss': multimarginloss_reference,
+    'CosineEmbeddingLoss': cosineembeddingloss_reference,
+    'TripletMarginLoss': tripletmarginloss_reference,
+    'MarginRankingLoss': marginrankingloss_reference,
+    'CTCLoss': ctcloss_reference,
+    'CrossEntropyLoss': cross_entropy_loss_reference,
+}
+
+
+# Criterion test descriptions; the loops further down append one dict per case.
+criterion_tests = []
+
+
+def single_batch_reference_criterion_fn(*args):
+    """Reference function for criterion supporting no batch dimensions.
+
+    The last positional argument is the criterion; everything before it is an
+    input or target. Each of those (or each member, when one is a list or
+    tuple) gets a leading dim of size 1, and the criterion is called on the
+    resulting flat argument list. For reduction 'none' the leading dim is
+    squeezed off the output again so it compares with the no-batch result.
+    """
+    criterion = args[-1]
+
+    def add_batch_dim(inp):
+        # A list/tuple of tensors is batched member by member.
+        if isinstance(inp, (list, tuple)):
+            return [member.unsqueeze(0) for member in inp]
+        return inp.unsqueeze(0)
+
+    def flatten(items):
+        # Recursively expand nested lists/tuples into one flat list.
+        if not isinstance(items, (list, tuple)):
+            return [items]
+        flat = []
+        for item in items:
+            flat.extend(flatten(item))
+        return flat
+
+    batched = [add_batch_dim(arg) for arg in args[:-1]]
+    single_batch_input_args = flatten(batched)
+
+    output = criterion(*single_batch_input_args)
+    if get_reduction(criterion) == 'none':
+        return output.squeeze(0)
+    # reduction is 'sum' or 'mean' which results in a scalar
+    return output
+
+
+# Check that regression criterion work with no batch dimensions
+regression_criterion_no_batch = [
+    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
+]
+reductions = ['none', 'mean', 'sum']
+for name, reduction in product(regression_criterion_no_batch, reductions):
+    regression_test_info = dict(
+        fullname=f"{name}_no_batch_dim_{reduction}",
+        # Bind both loop variables as defaults. A plain closure would look up
+        # ``reduction`` only when the constructor is finally called, by which
+        # time the loop variable holds a different value, so every test would
+        # build its criterion with the same reduction.
+        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
+        input_size=(3, ),
+        target_size=(3, ),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+        default_dtype=torch.double,
+    )
+    criterion_tests.append(regression_test_info)
+
+
+# KLDivLoss gets its own loop because its input must be log-probabilities,
+# so the input/target are produced by functions rather than plain sizes.
+for reduction in reductions:
+    regression_test_info = dict(
+        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
+        # Capture the current `reduction` as a keyword-only default. Without
+        # it the lambda would read the module-level `reduction` when the test
+        # finally runs, i.e. the last value assigned, not this iteration's.
+        constructor=lambda *, reduction=reduction: nn.KLDivLoss(reduction=reduction),
+        input_fn=lambda: torch.rand((3,)).log(),
+        target_fn=lambda: torch.rand((3,)),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+        default_dtype=torch.double,
+    )
+    criterion_tests.append(regression_test_info)
+
+
+# Check that classification criterion work with no batch dimensions
+# List of tuples of (name, input_fn, target_fn)
+classification_criterion_no_batch = [
+    (
+        'BCELoss',
+        lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)),
+        lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double)
+    ),
+    ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)),
+    ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])),
+    ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)),
+    (
+        'CosineEmbeddingLoss',
+        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
+        lambda: torch.tensor(1, dtype=torch.double)
+    ),
+    # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
+    ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
+    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
+    (
+        'TripletMarginLoss',
+        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
+        lambda: torch.randn(9, dtype=torch.double)
+    ),
+    ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
+]
+# Per-criterion overrides merged into the generated test description.
+classification_criterion_no_batch_extra_info: dict[str, dict] = {
+    'MultiLabelMarginLoss': {'check_gradgrad': False},
+}
+# TODO : Fix these discrepancies
+# Criteria listed here are generated with has_parity=False; all others default to True.
+classification_cpp_parity = {
+    'BCELoss': False,
+    'BCEWithLogitsLoss': False,
+    'HingeEmbeddingLoss': False,
+    'NLLLoss': False,
+    'SoftMarginLoss': False,
+}
+reductions = ['none', 'mean', 'sum']
+for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
+                                                      reductions):
+    classification_test_info = dict(
+        fullname=f"{name}_no_batch_dim_{reduction}",
+        # Bind `reduction` (like `name`, `input_fn` and `target_fn`) as a
+        # default argument. A free `reduction` would be looked up as a module
+        # global when the constructor is eventually called, so every test
+        # would use the last value assigned rather than the one in its name.
+        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
+        input_fn=lambda f=input_fn: f(),
+        target_fn=lambda f=target_fn: f(),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=True,
+        has_parity=classification_cpp_parity.get(name, True)
+    )
+    extra_info = classification_criterion_no_batch_extra_info.get(name, {})
+    classification_test_info.update(extra_info)
+    criterion_tests.append(classification_test_info)
+
+
+class NNTestCase(TestCase):
+    """Test case base class with Jacobian helpers built on subclass-provided hooks.
+
+    Subclasses supply ``_forward``, ``_get_parameters``, ``_zero_grad_parameters``
+    and ``_backward``; ``check_jacobian`` then compares an analytical Jacobian
+    (one backward pass per output element) against a numerical one.
+    """
+
+    # The hooks below are defined in classes inheriting from NNTestCase
+    @abstractmethod
+    def _forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _zero_grad_parameters(self, module: nn.Module) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _backward(self, module: nn.Module,
+                  input: _TensorOrTensors, output: torch.Tensor,
+                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
+                  create_graph: bool = False):
+        raise NotImplementedError
+
+    def _jacobian(self, input, num_out):
+        """Return zero matrices of shape (nelement, num_out) mirroring the nesting of ``input``."""
+        if not isinstance(input, (tuple, list)):
+            return torch.zeros(input.nelement(), num_out)
+        children = [self._jacobian(elem, num_out) for elem in input]
+        return tuple(children) if isinstance(input, tuple) else children
+
+    def _flatten_tensors(self, x):
+        """Flatten a tensor to 1-D (densifying sparse ones); recurse into containers as tuples."""
+        if not isinstance(x, torch.Tensor):
+            return tuple(self._flatten_tensors(a) for a in x)
+        dense = x.to_dense() if x.is_sparse else x
+        return dense.view(-1)
+
+    def _zero_grad_input(self, input):
+        """Zero and detach the existing ``.grad`` of every tensor in ``input`` that requires grad."""
+        if not isinstance(input, torch.Tensor):
+            for elem in input:
+                self._zero_grad_input(elem)
+            return
+        if input.requires_grad and input.grad is not None:
+            input.grad.zero_()
+            input.grad.detach_()
+
+    def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
+        """Build the Jacobian column by column, running one backward pass per output element.
+
+        Returns a tuple holding the input Jacobian (if ``jacobian_input``)
+        followed by the parameter Jacobian (if ``jacobian_parameters``).
+        """
+        output = self._forward(module, input)
+        n_out = output.nelement()
+
+        if jacobian_input:
+            jacobian_inp = self._jacobian(input, n_out)
+            input_blocks = list(_iter_tensors(jacobian_inp))
+
+        if jacobian_parameters:
+            total_params = sum(p.numel() for p in self._get_parameters(module)[0])
+            jacobian_param = torch.zeros(total_params, n_out)
+
+        for col in range(n_out):
+            params, grads = self._get_parameters(module)
+            # Parameters that have no gradient contribute zeros
+            grads = [
+                g if g is not None else torch.zeros_like(p)
+                for p, g in zip(params, grads, strict=True)
+            ]
+
+            # One-hot grad_output selecting output element `col`
+            one_hot = torch.zeros_like(output)
+            one_hot.view(-1)[col] = 1
+
+            # Tensors accumulate gradient across backward calls, so reset first
+            if jacobian_parameters:
+                self._zero_grad_parameters(module)
+            if jacobian_input:
+                self._zero_grad_input(input)
+            d_input = self._backward(module, input, output, one_hot)
+
+            if jacobian_input:
+                for block, grad in zip(input_blocks, _iter_tensors(d_input), strict=True):
+                    block[:, col] = grad.contiguous().view(-1)
+            if jacobian_parameters:
+                jacobian_param[:, col] = torch.cat(self._flatten_tensors(grads), 0)
+
+        results = []
+        if jacobian_input:
+            results.append(jacobian_inp)
+        if jacobian_parameters:
+            results.append(jacobian_param)
+        return tuple(results)
+
+    def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
+        """Compute the same Jacobians as ``_analytical_jacobian`` via ``_get_numerical_jacobian``."""
+        def forward_detached(*input):
+            return self._forward(module, input).detach()
+
+        results = []
+        if jacobian_input:
+            results.append(_get_numerical_jacobian(forward_detached, input, eps=1e-6))
+        if jacobian_parameters:
+            params, _ = self._get_parameters(module)
+            # get_numerical_jacobian returns a list of tuples but we require a tensor
+            per_param = [
+                _get_numerical_jacobian(forward_detached, input, target=p, eps=1e-6)[0][0]
+                for p in params
+            ]
+            results.append(torch.cat(per_param, 0))
+        return tuple(results)
+
+    def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
+        """Assert that analytical and numerical Jacobians differ by at most ``PRECISION``."""
+        jacobian_parameters = bool(self._get_parameters(module)[0])
+        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
+        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
+        analytical_t = list(_iter_tensors(analytical))
+        numerical_t = list(_iter_tensors(numerical))
+
+        # Empty tensors have no maximum, so they are left out of the comparison.
+        # TODO: compare structure (ensure analytic jacobian has correct shape)
+        differences = [
+            a.add(n, alpha=-1).abs().max()
+            for a, n in zip(analytical_t, numerical_t, strict=True)
+            if a.numel() != 0
+        ]
+        if differences:
+            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]
+
+
+class TestBase:
+    """Common base for test descriptions built from a constructor plus argument specs.
+
+    Every name in ``_required_arg_names`` may be supplied directly (``name``),
+    through a zero-argument factory (``name_fn``) or as sizes from which random
+    tensors are created (``name_size``). Resolved values are cached per name.
+    """
+
+    _required_arg_names = {'constructor_args', 'input', 'extra_args'}
+
+    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
+        self.desc = desc
+        self.fullname = fullname
+        self.constructor = constructor
+        self.reference_fn = reference_fn
+        for name in self._required_arg_names:
+            if any(key in kwargs for key in (name, name + '_fn', name + '_size')):
+                continue
+            # Only the argument tuples may be omitted; they default to empty.
+            if name not in {'constructor_args', 'extra_args'}:
+                raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!")
+            kwargs[name] = ()
+        self._extra_kwargs = kwargs
+        self._arg_cache = {}
+
+    def get_name(self):
+        """Return ``test_<fullname>``, or ``test_<constructor name>[_<desc>]`` without a fullname."""
+        if self.fullname is not None:
+            return 'test_' + self.fullname
+
+        pieces = ['test', self.constructor.__name__]
+        if self.desc:
+            pieces.append(self.desc)
+        return '_'.join(pieces)
+
+    def _unpack(self, value):
+        """Rebuild nested iterables with the same container types; tensors and scalars pass through."""
+        if not isinstance(value, torch.Tensor) and is_iterable(value):
+            return type(value)(self._unpack(v) for v in value)
+        return value
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', True)
+
+    @property
+    def extra_args(self):
+        return self._get_arg('extra_args', True)
+
+    def _get_arg(self, name, unpack):
+        """Resolve (and cache) argument ``name``; pass it through ``_unpack`` when ``unpack`` is set."""
+        assert name in self._required_arg_names
+
+        def tensors_from_sizes(spec):
+            # Lists are mapped element-wise, tensors are cast to double and
+            # anything else is treated as a size for torch.randn.
+            if isinstance(spec, list):
+                return [tensors_from_sizes(s) for s in spec]
+            if isinstance(spec, torch.Tensor):
+                return spec.double()
+            return torch.randn(spec)
+
+        cache = self._arg_cache
+        if name not in cache:
+            extra = self._extra_kwargs
+            fn_name, size_name = name + '_fn', name + '_size'
+
+            # Precedence: explicit value, then factory function, then sizes.
+            if name in extra:
+                cache[name] = extra[name]
+            elif fn_name in extra:
+                cache[name] = extra[fn_name]()
+            else:
+                assert size_name in extra, \
+                    f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}"
+                cache[name] = tensors_from_sizes(extra[size_name])
+
+        cached = cache[name]
+        return self._unpack(cached) if unpack else cached
+
+    def _get_input(self, unpack=True):
+        return self._get_arg('input', unpack)
+
+    def __call__(self, test_case):
+        raise NotImplementedError
+
+
+class ModuleTest(TestBase):
+    """Test description that exercises an ``nn.Module`` built by ``constructor``.
+
+    Calling the instance runs the optional reference comparison, the
+    non-contiguous input check, an optional pickle round trip and finally the
+    subclass-specific ``_do_test``. ``test_cuda`` compares CPU and CUDA results.
+    """
+
+    @abstractmethod
+    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
+        raise NotImplementedError
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Optional switches read from the test description; defaults shown inline.
+        self.jacobian_input = kwargs.get('jacobian_input', True)
+        self.should_test_cuda = kwargs.get('test_cuda', True)
+        self.should_test_pickle = kwargs.get('pickle', True)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.FIXME_no_cuda_gradgrad_comparison = \
+            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
+        self.precision = kwargs.get('precision', 2e-4)
+        self.check_forward_only = kwargs.get('check_forward_only', False)
+        # Fall back to the process-wide default dtype when none is requested.
+        self.default_dtype = kwargs.get('default_dtype')
+        if self.default_dtype is None:
+            self.default_dtype = torch.get_default_dtype()
+
+    def __call__(self, test_case):
+        """Run the CPU checks for this module under ``self.default_dtype``."""
+        with set_default_dtype(self.default_dtype):
+            module = self.constructor(*self.constructor_args)
+            input = self._get_input()
+
+            if self.reference_fn is not None:
+                out = test_case._forward(module, input)
+                # The reference works on copies so it cannot disturb the originals.
+                ref_input = deepcopy(input)
+                ref_module = deepcopy(module)
+                expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
+                test_case.assertEqual(out, expected_out, exact_dtype=False)
+            if self.check_forward_only:
+                return
+            self.test_noncontig(test_case, module, input)
+
+            if self.should_test_pickle:
+                # TODO: do this with in-memory files as soon as torch.save will support it
+                with tempfile.TemporaryFile() as f:
+                    test_case._forward(module, input)
+                    torch.save(module, f)
+                    f.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    module_copy = torch.load(f, weights_only=False)
+                    test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
+
+            self._do_test(test_case, module, input)
+
+    def noncontiguize(self, obj):
+        """Return a non-contiguous copy of a tensor, or of every tensor in a list/tuple."""
+        if isinstance(obj, list):
+            return [self.noncontiguize(o) for o in obj]
+        elif isinstance(obj, tuple):
+            return tuple(self.noncontiguize(o) for o in obj)
+        tensor = obj
+        ndim = tensor.dim()
+        # Always making only the last dimension noncontiguous is easy to hide
+        # bugs because .view(-1) will still work. So try to find a dim with size
+        # > 1 and make that non-contiguous, i.e., stack + select on the
+        # dimension directly after that.
+        dim = ndim
+        for d in range(ndim):
+            if tensor.size(d) > 1:
+                dim = d + 1
+                break
+        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
+        assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
+        noncontig.requires_grad = tensor.requires_grad
+        return noncontig
+
+    def test_noncontig(self, test_case, module, input):
+        """Check that outputs and gradients match for contiguous and non-contiguous data.
+
+        All four combinations of (non-)contiguous input and grad_output are
+        compared against a contiguous baseline. The check is skipped when any
+        input tensor is 0-dimensional.
+        """
+        # check no scalars, can't make non-contig
+        # A bare tensor is wrapped first: iterating it directly would yield its
+        # rows, so a 1-D input would look like a collection of 0-dim tensors
+        # and the whole check would be skipped by mistake.
+        inputs = (input,) if isinstance(input, torch.Tensor) else input
+        if any(isinstance(i, torch.Tensor) and i.dim() == 0 for i in inputs):
+            return
+
+        # Contiguous baseline.
+        test_case._zero_grad_parameters(module)
+        test_case._zero_grad_input(input)
+        with freeze_rng_state():
+            output = test_case._forward(module, input)
+            if getattr(module, "return_indices", False):
+                output = output[0]
+            grad_output = output.new(output.shape).normal_()
+            output = output.clone()
+            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
+            d_param = deepcopy(test_case._get_parameters(module)[1])
+
+        nc_input = self.noncontiguize(input)
+        nc_grad_output = self.noncontiguize(grad_output)
+        for contig_i, contig_g in product((True, False), repeat=2):
+            i = input if contig_i else nc_input
+            # Some ops, e.g., nn.Flatten, return gradient that shares
+            # storage with the grad_output. Hence we copy here.
+            go = deepcopy(grad_output if contig_g else nc_grad_output)
+            test_case._zero_grad_parameters(module)
+            test_case._zero_grad_input(i)
+            with freeze_rng_state():
+                out = test_case._forward(module, i)
+                if getattr(module, "return_indices", False):
+                    out = out[0]
+                grad = test_case._backward(module, i, out, go)
+
+                test_case.assertEqual(out, output)
+                test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
+                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
+
+    def test_cuda(self, test_case):
+        """Compare forward, backward and (optionally) double-backward between CPU and CUDA.
+
+        Raises ``unittest.SkipTest`` when CUDA testing is unavailable or disabled
+        for this test description.
+        """
+        if not TEST_CUDA or not self.should_test_cuda:
+            raise unittest.SkipTest('Excluded from CUDA tests')
+
+        with set_default_dtype(self.default_dtype):
+            cpu_input = self._get_input()
+
+            # presumably to_gpu converts double tensors to float via this map -- confirm against to_gpu
+            type_map = {torch.double: torch.float}
+            cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
+
+            is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple)
+
+            gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
+
+            # Build two independent modules and copy the CPU parameter values
+            # into the GPU module so both start from identical weights.
+            cpu_module = self.constructor(*self.constructor_args)
+            gpu_module = self.constructor(*self.constructor_args).float().cuda()
+            cpu_param = test_case._get_parameters(cpu_module)
+            gpu_param = test_case._get_parameters(gpu_module)
+            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0], strict=True):
+                gpu_p.data.copy_(cpu_p)
+
+            test_case._zero_grad_input(cpu_input_tuple)
+            test_case._zero_grad_input(gpu_input_tuple)
+            test_case._zero_grad_parameters(cpu_module)
+            test_case._zero_grad_parameters(gpu_module)
+            cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
+            gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
+            if getattr(cpu_module, "return_indices", False):
+                cpu_output = cpu_output[0]
+                gpu_output = gpu_output[0]
+            test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)
+
+            # Run backwards on CPU and GPU and compare results
+            for _ in range(5):
+                cpu_gradOutput = cpu_output.clone().normal_()
+                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
+                cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
+                gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
+                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
+                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1], strict=True):
+                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
+
+            # Run double-backwards on CPU and GPU and compare results
+            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
+                cpu_output = cpu_module(*cpu_input_tuple)
+                gpu_output = gpu_module(*gpu_input_tuple)
+                if getattr(cpu_module, "return_indices", False):
+                    cpu_output = cpu_output[0]
+                    gpu_output = gpu_output[0]
+
+                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
+                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
+                gpu_gradOutput.requires_grad = True
+
+                cpu_gradInputs = torch.autograd.grad(
+                    cpu_output,
+                    cpu_input_tuple + tuple(cpu_module.parameters()),
+                    cpu_gradOutput,
+                    create_graph=True)
+                gpu_gradInputs = torch.autograd.grad(
+                    gpu_output,
+                    gpu_input_tuple + tuple(gpu_module.parameters()),
+                    gpu_gradOutput,
+                    create_graph=True)
+
+                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs, strict=True):
+                    test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)
+
+                # We mix output into the second backwards computation so that
+                # torch.autograd.grad doesn't complain that some inputs
+                # are unreachable (which can happen if you differentiate
+                # only on the gradient.
+                if is_any_input_complex:
+                    outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
+                    outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
+                else:
+                    outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
+                    outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
+
+                cpu_gg = torch.autograd.grad(
+                    outputs_cpu,
+                    cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
+                    retain_graph=True)
+                gpu_gg = torch.autograd.grad(
+                    outputs_gpu,
+                    gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
+                    retain_graph=True)
+                # NOTE(review): this re-compares the first-order gradients left over
+                # from the loop above; the second-order results are checked below.
+                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
+                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg, strict=True):
+                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)
+
+            self.test_noncontig(test_case, gpu_module, gpu_input_tuple)
+
+
+class InputVariableMixin:
+    """Mixin whose ``_get_input`` marks floating-point and complex input tensors as requiring grad."""
+
+    def _get_input(self):
+        raw_input = TestBase._get_input(self, False)  # type: ignore[arg-type]
+
+        def enable_grad(obj):
+            # Containers are rebuilt with their own type around the converted elements.
+            if not isinstance(obj, torch.Tensor):
+                return type(obj)(enable_grad(elem) for elem in obj)
+            if obj.is_floating_point() or obj.is_complex():
+                obj.requires_grad = True
+            return obj
+
+        return enable_grad(raw_input)
+
+
+class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
+    """Module test adding gradient checks, an in-place check and dtype/device round trips."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Optional switches read from the test description; defaults shown inline.
+        self.cudnn = kwargs.get('cudnn', False)
+        self.check_inplace = kwargs.get('check_inplace', False)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.skip_double = kwargs.get('skip_double', False)
+        self.skip_half = kwargs.get('skip_half', False)
+        self.with_tf32 = kwargs.get('with_tf32', False)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
+        self.test_cpu = kwargs.get('test_cpu', True)
+        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
+        self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode')
+        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
+        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)
+
+    def _check_gradients(self, test_case, module, input_tuple):
+        """Run gradcheck (or ``check_jacobian`` for sparse gradients) and, if enabled, gradgradcheck."""
+        params = tuple(x for x in module.parameters())
+        num_inputs = len(input_tuple)
+
+        def fn_to_gradcheck(*inputs_and_params, **kwargs):
+            # Parameters are passed only so they are differentiated; the
+            # forward call itself receives just the leading inputs.
+            assert not kwargs
+            return test_case._forward(module, inputs_and_params[:num_inputs])
+
+        # gradcheck doesn't support operators that take in dense inputs but
+        # return sparse parameters. This only happens in the case of nn.Embedding
+        # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
+        # is a slightly different version of gradcheck that can handle this.
+        if self.has_sparse_gradients:
+            assert num_inputs == 1
+            test_input_jacobian = torch.is_floating_point(input_tuple[0])
+            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
+        else:
+            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
+                                           check_batched_grad=self.check_batched_grad,
+                                           fast_mode=self.gradcheck_fast_mode,
+                                           check_forward_ad=self.supports_forward_ad))
+
+        if self.check_gradgrad:
+            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
+                                               check_batched_grad=self.check_batched_grad,
+                                               fast_mode=self.gradcheck_fast_mode,
+                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))
+
+    def _do_test(self, test_case, module, input):
+        """Run all module checks with the torch thread count temporarily set to 1.
+
+        The previous thread count is restored in a ``finally`` block so that a
+        failing assertion does not leave the process single-threaded for the
+        tests that run afterwards.
+        """
+        num_threads = torch.get_num_threads()
+        torch.set_num_threads(1)
+        try:
+            self._run_module_checks(test_case, module, input)
+        finally:
+            torch.set_num_threads(num_threads)
+
+    def _run_module_checks(self, test_case, module, input):
+        """Gradient, repr, in-place and dtype/device conversion checks for ``module``."""
+        input_tuple = input if isinstance(input, tuple) else (input,)
+
+        self._check_gradients(test_case, module, input_tuple)
+
+        # check if module can be printed
+        module.__repr__()
+
+        if self.check_inplace:
+            # check if the inplace variant of the module gives the same result
+            # as the out-of-place
+
+            # check_inplace doesn't support multiple input tensors, since we don't have any modules
+            # that modify the inputs in-place and that accept more than one input
+            assert len(input_tuple) == 1
+            input = input_tuple[0]
+
+            module_ip = self.constructor(*self.constructor_args, inplace=True)
+
+            # The out-of-place module must leave the input's version counter untouched.
+            input_version = input._version
+            with freeze_rng_state():
+                output = module(input)
+            test_case.assertEqual(input._version, input_version)
+
+            # The in-place module runs on a clone, whose version counter must change.
+            input_ip = deepcopy(input)
+            input_ip_clone = input_ip.clone()
+            with freeze_rng_state():
+                output_ip = module_ip(input_ip_clone)
+            test_case.assertNotEqual(input_ip_clone._version, input_version)
+            test_case.assertEqual(output, output_ip)
+            grad = output.data.clone().normal_()
+            if input.grad is not None:
+                with torch.no_grad():
+                    input.grad.zero_()
+            if input_ip.grad is not None:
+                with torch.no_grad():
+                    input_ip.grad.zero_()
+            output.backward(grad)
+            output_ip.backward(grad)
+            test_case.assertEqual(input.grad, input_ip.grad)
+
+        def assert_module_parameters_are(tensor_type, device_id=None):
+            for p in module.parameters():
+                test_case.assertIsInstance(p, tensor_type)
+                if device_id is not None:
+                    test_case.assertEqual(p.get_device(), device_id)
+
+        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
+            # check that cuda() moves module parameters to correct GPU device,
+            # and that float() casts parameters correctly
+            input_tuple = tuple(t.cuda() for t in input_tuple)
+            module.float().cuda()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+            if torch.cuda.device_count() > 1:
+                input_tuple = tuple(t.cuda(1) for t in input_tuple)
+                module.cuda(1)
+                with torch.cuda.device(1):
+                    module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
+        else:
+            # check that float()/double() casters work correctly
+            def to_type(tensor, real, complex):
+                if tensor.is_complex():
+                    return tensor.to(complex)
+                elif tensor.is_floating_point():
+                    return tensor.to(real)
+                else:
+                    return tensor
+
+            def to_half(x):
+                # TODO: torch.complex32 when properly supported
+                return to_type(x, torch.float16, None)
+
+            def to_single(x):
+                return to_type(x, torch.float32, torch.complex64)
+
+            def to_double(x):
+                return to_type(x, torch.float64, torch.complex128)
+
+            # to float
+            input_tuple = tuple(to_single(t) for t in input_tuple)
+            module.float()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.FloatTensor)
+
+            # and back to double
+            input_tuple = tuple(to_double(t) for t in input_tuple)
+            module.double()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.DoubleTensor)
+
+            if TEST_CUDA and self.should_test_cuda:
+                # check that cuda() moves module parameters to correct GPU device,
+                # and that float() casts parameters correctly
+
+                # to GPU0
+                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
+                module.float().cuda()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                # to CPU
+                input_tuple = tuple(t.cpu() for t in input_tuple)
+                module.cpu()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.FloatTensor)
+
+                # back to GPU0
+                input_tuple = tuple(t.cuda() for t in input_tuple)
+                module.cuda()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                # test that forwards of module runs correctly without cuDNN
+                if self.cudnn:
+                    with torch.backends.cudnn.flags(enabled=False):
+                        module(*input_tuple)
+                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                if torch.cuda.device_count() >= 2:
+                    # test cross-GPU transfer works
+                    # to GPU1
+                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
+                    module.cuda(1)
+                    with torch.cuda.device(1):
+                        module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
+
+                if not self.skip_double:
+                    # test double()
+                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
+                    module.double().cuda()
+                    module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]
+
+                # test half()
+                if not self.skip_half:
+                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
+                    module.half().cuda()
+                    module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
+
+    def _get_target(self):
+        # NOTE(review): 'target' does not appear in the `_required_arg_names`
+        # defined on the visible base classes, which `_get_arg` asserts on --
+        # confirm whether this method is still reachable.
+        return self._get_arg('target', False)
+
+    @property
+    def constructor_args(self):
+        # Unlike the base-class property, the cached value is returned without unpacking.
+        return self._get_arg('constructor_args', False)
+
+
+class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
+    """Test description for a loss (criterion) module.
+
+    Calling the instance runs the CPU checks: ``repr``/``str``, an optional
+    comparison against ``reference_fn``, then ``gradcheck`` and (optionally)
+    ``gradgradcheck``. ``test_cuda`` compares CPU and CUDA forward/backward
+    results for one dtype. The module constructor, the input and the target
+    are obtained through ``self.constructor``, ``self._get_input()`` and
+    ``self._get_arg(...)``, which are provided by the base classes.
+    """
+    # TODO: check that criterions don't ignore grad_output
+
+    _required_arg_names = TestBase._required_arg_names.union({'target'})
+
+    def __init__(self, *args, **kwargs):
+        """Forward all arguments to the base initializer and record the test switches.
+
+        Switches read from ``kwargs`` (with their defaults): ``test_cuda`` (True,
+        stored as ``should_test_cuda``), ``check_forward_only`` (False),
+        ``check_gradgrad`` (True), ``check_half`` (True), ``check_bfloat16``
+        (False), ``check_complex`` (False), ``test_cpu`` (True), ``with_tf32``
+        (True), ``tf32_precision`` (0.001), ``check_batched_grad`` (True) and
+        ``default_dtype`` (the current torch default dtype when omitted).
+        """
+        super().__init__(*args, **kwargs)
+        self.should_test_cuda = kwargs.get('test_cuda', True)
+        self.check_forward_only = kwargs.get('check_forward_only', False)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.check_half = kwargs.get('check_half', True)
+        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
+        self.check_complex = kwargs.get('check_complex', False)
+        self.test_cpu = kwargs.get('test_cpu', True)
+        self.with_tf32 = kwargs.get('with_tf32', True)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
+        self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.default_dtype = kwargs.get('default_dtype')
+        if self.default_dtype is None:
+            self.default_dtype = torch.get_default_dtype()
+
+    def __call__(self, test_case):
+        """Run the CPU checks for this criterion under ``self.default_dtype``.
+
+        ``test_case`` must provide ``_forward_criterion`` and ``assertEqual``.
+        Returns ``None``; stops after the forward comparison when
+        ``check_forward_only`` is set.
+        """
+        with set_default_dtype(self.default_dtype):
+            module = self.constructor(*self.constructor_args)
+            input = self._get_input()
+
+            # Check that these methods don't raise errors
+            repr(module)
+            str(module)
+
+            target = self._get_target()
+
+            if self.reference_fn is not None:
+                out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
+                # The reference gets copies so it cannot disturb the tensors used below.
+                ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
+                expected_out = self.reference_fn(*ref_args)
+                test_case.assertEqual(out, expected_out)
+
+            if self.check_forward_only:
+                return
+
+            params = tuple(module.parameters())
+            # gradcheck invokes apply_fn(*inputs), so the tuple must follow the
+            # apply_fn signature: input(s), then target, then the parameters.
+            # Placing the parameters before the target would bind the first
+            # parameter to ``target`` whenever the module has parameters.
+            if not isinstance(input, tuple):
+                inputs = (input, target) + params
+
+                def apply_fn(input, target, *params):
+                    return module(input, target)
+            else:
+                # NOTE: a tuple input is assumed to hold exactly two tensors.
+                inputs = input + (target,) + params
+
+                def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
+                    return module(input1, input2, target)
+
+            gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
+
+            if self.check_gradgrad:
+                gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
+
+    def test_cuda(self, test_case, dtype, extra_args=None):
+        """Compare CPU and CUDA forward and backward results for ``dtype``.
+
+        Raises ``unittest.SkipTest`` when ``TEST_CUDA`` is false or the test was
+        created with ``test_cuda=False``. For ``torch.half`` / ``torch.bfloat16``
+        the CPU side is rebuilt without the dtype conversion and the comparison
+        tolerance is ``1e-1`` instead of ``4e-4``.
+        """
+        def convert_dtype(obj, dtype, requires_grad=False):
+            # Recursively convert tensors (also inside tuples); leave other objects alone.
+            if isinstance(obj, torch.Tensor):
+                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
+            elif isinstance(obj, tuple):
+                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
+            else:
+                return obj
+
+        if not TEST_CUDA or not self.should_test_cuda:
+            raise unittest.SkipTest('Excluded from CUDA tests')
+
+        with set_default_dtype(self.default_dtype):
+            cpu_input = self._get_input()
+            cpu_target = self._get_target()
+            cpu_module = self.constructor(*self.constructor_args)
+            gpu_module = self.constructor(*self.constructor_args)
+
+            # Convert input, target and module parameters to dtype
+            cpu_input = convert_dtype(cpu_input, dtype, True)
+            # Only floating point / complex targets are converted; other targets keep their dtype.
+            if cpu_target.is_floating_point() or cpu_target.is_complex():
+                cpu_target = convert_dtype(cpu_target, dtype)
+            cpu_module.type(dtype)
+            gpu_module.type(dtype)
+
+            # GPU setup
+            gpu_input = to_gpu(cpu_input)
+            gpu_target = to_gpu(cpu_target)
+            gpu_module.cuda()
+
+            # torch.HalfTensor doesn't support most operations, converting back to default
+            if dtype in {torch.half, torch.bfloat16}:
+                cpu_input = self._get_input()
+                cpu_target = self._get_target()
+                # Loss modules with weights require consistent input/module weight types
+                cpu_module = self.constructor(*self.constructor_args)
+
+            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
+            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
+            # dtype used to be able to be None, so set precision in this way instead of a precision map
+            test_case.assertEqual(cpu_output, gpu_output,
+                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
+
+            cpu_gradInput = test_case._backward_criterion(
+                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
+            gpu_gradInput = test_case._backward_criterion(
+                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
+            # dtype used to be able to be None, so set precision in this way instead of a precision map
+            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
+                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
+
+    def _get_target(self):
+        """Look up this test's ``target`` argument through ``_get_arg``."""
+        return self._get_arg('target', False)
+
+    @property
+    def constructor_args(self):
+        """Look up this test's ``constructor_args`` argument through ``_get_arg``."""
+        return self._get_arg('constructor_args', False)
+
+    @property
+    def extra_args(self):
+        """Look up this test's ``extra_args`` argument through ``_get_arg``."""
+        return self._get_arg('extra_args', False)
+
+
+def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
+    """Check that ``op`` gives matching results in float32 and in bfloat16.
+
+    Forward and backward are run in float32 first. ``op`` is then converted with
+    ``op.bfloat16()`` and run again on bfloat16 copies of the same input and
+    upstream gradient. Outputs and input gradients are compared through
+    ``test_case.assertEqual`` with ``prec`` used as both ``atol`` and ``rtol``.
+    When ``scale_factor`` is given, the float32 input is replaced by a bfloat16
+    ``torch.rand`` sample multiplied by ``scale_factor`` and cast to float32.
+    """
+    # fp32 compute
+    fp32_input = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
+    if scale_factor is not None:
+        scaled = torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor
+        fp32_input = scaled.float().requires_grad_()
+    fp32_out = op(fp32_input)
+    fp32_grad_out = torch.randn_like(fp32_out, device=device)
+    fp32_out.backward(fp32_grad_out)
+
+    # bfloat16 compute
+    bf16_op = op.bfloat16()
+    bf16_input = fp32_input.detach().bfloat16().requires_grad_()
+    bf16_grad_out = fp32_grad_out.bfloat16()
+    bf16_out = bf16_op(bf16_input)
+    bf16_out.backward(bf16_grad_out)
+
+    test_case.assertEqual(fp32_out, bf16_out, atol=prec, rtol=prec, exact_dtype=False)
+    test_case.assertEqual(
+        fp32_input.grad.data, bf16_input.grad.data, atol=prec, rtol=prec, exact_dtype=False
+    )
+
+def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
+    """Run ``module`` on ``inp`` and check the results expected for an empty input.
+
+    Unless ``inference`` is set, ``inp`` is marked as requiring grad, the output
+    is back-propagated with a random upstream gradient, and every trainable
+    parameter gradient as well as ``inp.grad`` must equal zeros. With
+    ``check_size`` the output size must equal the input size.
+    """
+    with_grad = not inference
+    if with_grad:
+        inp.requires_grad_(True)
+
+    out = module(inp)
+
+    if with_grad:
+        upstream_grad = torch.rand_like(out)
+        out.backward(upstream_grad)
+
+    if check_size:
+        test_case.assertEqual(out.size(), inp.size())
+
+    if inference:
+        return
+
+    trainable = (p for p in module.parameters() if p.requires_grad)
+    for param in trainable:
+        test_case.assertEqual(param.grad, torch.zeros_like(param.grad))
+    test_case.assertEqual(inp.grad, torch.zeros_like(inp))
+
+
+def _create_basic_net():
+    """Build three small modules that only hold dummy parameters and buffers.
+
+    Returns a tuple ``(layer, net, seq)``: a standalone ``Layer``, a ``Net``
+    that owns its own ``Layer`` as ``l1``, and an ``nn.Sequential`` containing
+    that same ``Net`` instance twice.
+    """
+    class Layer(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
+            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))
+
+    class Net(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.l1 = Layer()
+            self.dummy_param = nn.Parameter(torch.empty(3, 5))
+            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))
+
+    standalone_layer = Layer()
+    net = Net()
+    # The same Net object appears twice on purpose (shared submodule).
+    doubled = nn.Sequential(net, net)
+
+    return standalone_layer, net, doubled
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b41e24b96caf24558c6947b6350c7b9c9ac8b7a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py
@@ -0,0 +1,2303 @@
+# mypy: ignore-errors
+
+import functools
+import itertools
+import sys
+import unittest
+from copy import deepcopy
+from enum import Enum
+from typing import Any, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Parameter
+from torch.optim import (
+    Adadelta,
+    Adafactor,
+    Adagrad,
+    Adam,
+    Adamax,
+    AdamW,
+    ASGD,
+    LBFGS,
+    Muon,
+    NAdam,
+    Optimizer,
+    RAdam,
+    RMSprop,
+    Rprop,
+    SGD,
+    SparseAdam,
+)
+from torch.optim.lr_scheduler import (
+    ConstantLR,
+    ExponentialLR,
+    LinearLR,
+    PolynomialLR,
+    ReduceLROnPlateau,
+    StepLR,
+)
+from torch.testing._internal.common_device_type import tol, toleranceOverride
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_utils import (
+    _TestParametrizer,
+    skipIfMPS,
+    skipIfTorchDynamo,
+    TEST_WITH_TORCHDYNAMO,
+)
+from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
+
+
+# Device types for which the optim_inputs_func_* samplers below also return
+# their `cuda_supported_configs` (the "capturable" configurations).
+CUDA_CONFIG_GPUS = ["cuda", "xpu"]
+
+
+class OptimizerInput:
+    """Contains args / kwargs to be passed to an optimizer constructor.
+
+    Attributes (declared in ``__slots__``):
+        params: parameters or param groups for the optimizer, or ``None`` when
+            the test using this input supplies the parameters itself.
+        kwargs: keyword arguments for the optimizer constructor.
+        desc: short label describing this configuration.
+    """
+
+    __slots__ = ["params", "kwargs", "desc"]
+
+    def __init__(
+        self,
+        # ``None`` is part of the annotation because the samplers in this file
+        # pass ``params=None``.
+        params: Union[
+            list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]], None
+        ],
+        kwargs: dict[str, Any],
+        desc: str = "",
+    ):
+        # params can be a list of Tensors OR param_groups OR None
+        self.params = params
+        self.kwargs = kwargs
+        self.desc = desc
+
+    def __repr__(self):
+        """Return ``params=..., kwargs=..., desc=...`` using the stored values."""
+        return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"
+
+
+class OptimizerErrorEnum(Enum):
+    """Enumerates when an error is raised when testing optimizers."""
+
+    # The error is expected while the optimizer is being constructed; this is
+    # the default `error_on` of ErrorOptimizerInput.
+    CONSTRUCTION_ERROR = 0
+    # The error is expected later than construction -- presumably when
+    # optimizer.step() runs (NOTE(review): confirm against the consuming tests).
+    STEP_ERROR = 1
+
+
+class ErrorOptimizerInput:
+    """
+    Pairs an OptimizerInput with the error it is expected to trigger.
+
+    ``error_on`` is an OptimizerErrorEnum member telling when the error is
+    expected (``CONSTRUCTION_ERROR`` by default); ``error_type`` and
+    ``error_regex`` hold the type and message pattern of the resulting error.
+    """
+
+    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]
+
+    def __init__(
+        self,
+        optimizer_error_input,
+        *,
+        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        error_type=RuntimeError,
+        error_regex="",
+    ):
+        # The input under test and the point at which it should fail.
+        self.optimizer_error_input, self.error_on = optimizer_error_input, error_on
+        # What the failure should look like.
+        self.error_type, self.error_regex = error_type, error_regex
+
+
+class OptimizerInfo:
+    """Optimizer information to be used in testing.
+
+    Bundles the optimizer class with its sample-input generators, the flags
+    describing what the optimizer supports, and the decorators / skips that
+    should be applied to generated tests.
+    """
+
+    def __init__(
+        self,
+        optim_cls: Optimizer,  # Class object for the Optimizer under test
+        *,
+        # Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
+        # to the test using the OptimizerInfo. OptimizerInput.params is likely None.
+        # Can optionally take in device to filter out certain unsupported configs
+        optim_inputs_func,
+        # Tuple of lists of lambdas to generate LRScheduler instances to run with the optimizer
+        # for the LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
+        # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
+        # LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
+        # A few optimizers like SGD and Adam will test more LRSchedulers.
+        scheduler_inputs=(
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
+        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
+        supported_impls: tuple[str, ...] = ("foreach", "differentiable"),
+        # A subset of all flags, signifying which ones were only supported after the
+        # original optimizer had already been released. aka impls where we need to check BC.
+        not_og_supported_flags: tuple[str, ...] = (
+            "foreach",
+            "differentiable",
+            "maximize",
+            "capturable",
+        ),
+        # the optim supports passing in sparse gradients as well as dense grads
+        supports_sparse: bool = False,
+        # the optimizer constructor supports passing in capturable as a kwarg
+        has_capturable_arg: bool = False,
+        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
+        only_supports_sparse_grads: bool = False,
+        # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
+        # with especially tuned hyperparameters. These only apply if the optimizer supports
+        # sparse parameters or grads.
+        metadata_for_sparse=({}, []),
+        # the optim supports complex parameters
+        supports_complex: bool = True,
+        # whether the optimizer.step() function requires a closure to be passed
+        step_requires_closure: bool = False,
+        # whether the optimizer supports per-param options with parameter groups
+        supports_param_groups: bool = True,
+        # whether the optimizer supports parameters on multiple devices
+        supports_multiple_devices: bool = True,
+        skips=(),  # Indicates which tests to skip
+        decorators=None,  # Additional decorators to apply to generated tests
+        optim_error_inputs_func=None,  # Function to generate optim inputs that error
+        supports_fused_on: tuple[str, ...] = (),
+    ):
+        # What is under test and how its sample inputs are produced.
+        self.optim_cls = optim_cls
+        self.optim_inputs_func = optim_inputs_func
+        self.optim_error_inputs_func = optim_error_inputs_func
+        self.scheduler_inputs = scheduler_inputs
+        self.metadata_for_sparse = metadata_for_sparse
+
+        # Which implementations / flags the optimizer supports.
+        self.supported_impls = supported_impls
+        self.not_og_supported_flags = not_og_supported_flags
+        self.supports_fused_on = supports_fused_on
+        self.has_capturable_arg = has_capturable_arg
+
+        # Capability switches.
+        self.supports_sparse = supports_sparse
+        self.only_supports_sparse_grads = only_supports_sparse_grads
+        self.supports_complex = supports_complex
+        self.step_requires_closure = step_requires_closure
+        self.supports_param_groups = supports_param_groups
+        self.supports_multiple_devices = supports_multiple_devices
+
+        # One flat tuple: the decorators first, then the skips.
+        self.decorators = tuple(decorators or ()) + tuple(skips or ())
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        """Return the list of decorators that apply to one generated test.
+
+        ``DecorateInfo`` entries contribute their ``decorators`` only when
+        ``is_active(...)`` is true for the given test; any other entry is
+        included as-is.
+        """
+        selected = []
+        for entry in self.decorators:
+            if not isinstance(entry, DecorateInfo):
+                selected.append(entry)
+            elif entry.is_active(test_class, test_name, device, dtype, param_kwargs):
+                selected.extend(entry.decorators)
+        return selected
+
+    @property
+    def name(self):
+        """The ``__name__`` of the optimizer class under test."""
+        optimizer_class = self.optim_cls
+        return optimizer_class.__name__
+
+
+class optims(_TestParametrizer):
+    """Decorator for specifying a list of optimizers over which to run a test."""
+
+    def __init__(self, optim_info_iterable, dtypes=None):
+        """Store the optimizer infos as a list and the dtypes to combine them with."""
+        self.optim_info_list = list(optim_info_iterable)
+
+        # optimizers aren't limited to be one dtype as parameters can have different dtypes
+        # We default to torch.float32, but dtypes should be specified through passed in
+        # parameters.
+        if dtypes is None:
+            dtypes = [torch.float32]
+        self.dtypes = dtypes
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """Yield one ``(test, name, param_kwargs, decorator_fn)`` tuple per (optimizer, dtype) pair.
+
+        Raises ``RuntimeError`` when ``device_cls`` is ``None``, i.e. when the
+        decorator is used outside a device-specific context.
+        """
+        if device_cls is None:
+            raise RuntimeError(
+                "The @optims decorator is only intended to be used in a device-specific "
+                "context; use it with instantiate_device_type_tests() instead of "
+                "instantiate_parametrized_tests()"
+            )
+
+        combinations = itertools.product(self.optim_info_list, self.dtypes)
+        for optim_info, dtype in combinations:
+            # Construct the test name; device / dtype parts are handled outside.
+            # See [Note: device and dtype suffix placement]
+            test_name = optim_info.name
+
+            # Construct parameter kwargs to pass to the test.
+            param_kwargs = {"optim_info": optim_info, "dtype": dtype}
+
+            try:
+
+                @functools.wraps(test)
+                def test_wrapper(*args, **kwargs):
+                    return test(*args, **kwargs)
+
+                # Binds everything except param_kwargs, which the caller supplies.
+                decorator_fn = functools.partial(
+                    optim_info.get_decorators,
+                    generic_cls.__name__,
+                    test.__name__,
+                    device_cls.device_type,
+                    dtype,
+                )
+
+                yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+            except Exception:
+                # Provides an error message for debugging before rethrowing the exception
+                print(
+                    f"Failed to instantiate {test_name} for module {optim_info.name}!"
+                )
+                raise
+
+
+# Helper function for generating error inputs for all optimizers, used below.
+def get_error_inputs_for_all_optims(device, dtype):
+    """Return the error cases that are shared by all optimizers.
+
+    The cases are only generated when the device type is ``"cpu"``; any other
+    device type gets an empty list.
+    """
+    if _get_device_type(device) != "cpu":
+        return []
+
+    # Creating 2D parameters for compatibility with Muon.
+    first_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
+    second_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
+
+    # A bare Parameter instead of an iterable.
+    bare_param = ErrorOptimizerInput(
+        OptimizerInput(params=first_param, kwargs={}, desc="invalid param type"),
+        error_type=TypeError,
+        error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
+    )
+    # The same Parameter twice within one group.
+    duplicate_in_group = ErrorOptimizerInput(
+        OptimizerInput(
+            params=[first_param, first_param],
+            kwargs={},
+            desc="a param group cannot have duplicate parameters",
+        ),
+        error_type=UserWarning,
+        error_regex=".*a parameter group with duplicate parameters.*",
+    )
+    # The same Parameter in two different groups.
+    duplicate_across_groups = ErrorOptimizerInput(
+        OptimizerInput(
+            params=[{"params": first_param}, {"params": first_param}],
+            kwargs={},
+            desc="duplicate parameters should not occur across param groups either",
+        ),
+        error_type=ValueError,
+        error_regex="some parameters appear in more than one parameter group",
+    )
+    # A Tensor lr with two elements.
+    multi_element_lr = ErrorOptimizerInput(
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor([0.001, 0.001])},
+            desc="Tensor lr must be 1-element",
+        ),
+        error_type=ValueError,
+        error_regex="Tensor lr must be 1-element",
+    )
+    # A named parameter mixed with an unnamed one.
+    mixed_named_params = ErrorOptimizerInput(
+        OptimizerInput(
+            params=[("weight", first_param), second_param],
+            kwargs={},
+            desc="all optimizer params should be with/without names",
+        ),
+        error_type=ValueError,
+        error_regex="all optimizer params should be with/without names. Some param names are missing",
+    )
+    # An unnamed param group followed by a named one.
+    mixed_named_groups = ErrorOptimizerInput(
+        OptimizerInput(
+            params=[
+                {"params": [first_param], "lr": 1e-2},
+                {"params": [("weight", second_param)]},
+            ],
+            kwargs={},
+            desc="all optimizer param groups should be with/without names.",
+        ),
+        error_type=ValueError,
+        error_regex="all optimizer param groups should be with/without names. "
+        "cannot add param group with names to the optimizer",
+    )
+
+    return [
+        bare_param,
+        duplicate_in_group,
+        duplicate_across_groups,
+        multi_element_lr,
+        mixed_named_params,
+        mixed_named_groups,
+    ]
+
+
+# ------------------------------------------------------------------------------------------
+# NOTE: [optimizer kwarg categories]
+# We categorize optimizer kwargs as 3 types:
+#  1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
+#     algorithms and thus only show up for certain optimizers. There are many of these, so I
+#     do not bother gathering them all and listing them here. The converse to these would be
+#     global flags that every optimizer ideally _should_ support. We break global flags into
+#     2 further categories and list them all below.
+#  2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
+#     global-friendly flags are global flags that play nicely with all other global flags,
+#     i.e., are mutually exclusive in function. This means that any pair of the following
+#     flags can be toggled at once (e.g., maximize and weight_decay). Furthermore, any of the
+#     following flags theoretically can be enabled with ANY other global flag, including the
+#     cliquey ones (e.g, capturable and foreach).
+#  3. global-cliquey = ["foreach", "fused", "differentiable"]
+#     global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
+#     usually because they contradict each other in function. For example, one should not flip
+#     both foreach AND fused to True, because they are two differing performance optimizations
+#     in which you can only opt into one.
+#
+# The following optim_inputs_func_* sampling functions only return constructor combinations of
+# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
+# well with additional kwargs. On the flip side of the same coin, we reserve setting the
+# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
+
+
+def optim_inputs_func_adadelta(device, dtype=None):
+    """Return the sample constructor configurations for Adadelta.
+
+    Every entry is an ``OptimizerInput`` with ``params=None``. The three
+    ``capturable`` configurations are appended only when the device type is
+    listed in ``CUDA_CONFIG_GPUS``. ``dtype`` is accepted but not used.
+    """
+    capturable_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="capturable with weight decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    base_specs = (
+        ({}, "default"),
+        ({"lr": 0.01}, "non-default lr"),
+        ({"weight_decay": 0.1}, "nonzero weight_decay"),
+        ({"maximize": True}, "maximize"),
+        ({"weight_decay": 0.1, "maximize": True}, "maximize, weight_decay"),
+        ({"rho": 0.95, "weight_decay": 0.9}, "rho"),
+    )
+    configs = [
+        OptimizerInput(params=None, kwargs=kwargs, desc=desc)
+        for kwargs, desc in base_specs
+    ]
+
+    if _get_device_type(device) in CUDA_CONFIG_GPUS:
+        configs.extend(capturable_configs)
+    return configs
+
+
+def optim_error_inputs_func_adadelta(device, dtype):
+    """Return the shared error cases plus, on CPU, the invalid ``rho`` case for Adadelta."""
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) != "cpu":
+        return error_inputs
+
+    invalid_rho = ErrorOptimizerInput(
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": 1e-2, "rho": 1.1},
+            desc="rho should be between 0 and 1",
+        ),
+        error_type=ValueError,
+        error_regex="Invalid rho value: 1.1",
+    )
+    error_inputs.append(invalid_rho)
+    return error_inputs
+
+
+def optim_inputs_func_adafactor(device, dtype=None):
+    """Return the sample constructor configurations for Adafactor.
+
+    Every entry is an ``OptimizerInput`` with ``params=None``. Neither
+    ``device`` nor ``dtype`` influences the result.
+    """
+    # (constructor kwargs, description) pairs, in the order they are returned.
+    specs = (
+        ({}, "default"),
+        ({"weight_decay": 0.1, "lr": 0.01}, "nonzero weight_decay"),
+        ({"weight_decay": 0.1, "maximize": True}, "maximize"),
+        ({"beta2_decay": -1.0}, "non-default beta2_decay"),
+        ({"d": 1.5}, "non-default clipping threshold d"),
+    )
+    return [
+        OptimizerInput(params=None, kwargs=kwargs, desc=desc)
+        for kwargs, desc in specs
+    ]
+
+
+def optim_error_inputs_func_adafactor(device, dtype):
+    """Return the shared error cases plus, on CPU, the Adafactor-specific ones.
+
+    The Adafactor cases cover a negative first ``eps`` entry, ``d=0.0``, a
+    positive ``beta2_decay`` and a complex parameter; the last one is marked as
+    a ``STEP_ERROR`` rather than a construction error.
+    """
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) != "cpu":
+        return error_inputs
+
+    # A complex tensor that already carries a gradient.
+    complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
+    complex_param.grad = torch.rand_like(complex_param)
+
+    negative_eps = ErrorOptimizerInput(
+        OptimizerInput(
+            params=None,
+            kwargs={"eps": (-1e-30, 1e-3)},
+            desc="epsilon1 should be >= 0",
+        ),
+        error_type=ValueError,
+        error_regex="epsilon1 should be >= 0",
+    )
+    invalid_d = ErrorOptimizerInput(
+        OptimizerInput(params=None, kwargs={"d": 0.0}, desc="invalid d"),
+        error_type=ValueError,
+        error_regex="Clipping threshold d should be >= 1",
+    )
+    invalid_beta2_decay = ErrorOptimizerInput(
+        OptimizerInput(
+            params=None,
+            kwargs={"beta2_decay": 0.8},
+            desc="invalid beta2_decay",
+        ),
+        error_type=ValueError,
+        error_regex="beta2_decay should be <= 0",
+    )
+    complex_unsupported = ErrorOptimizerInput(
+        OptimizerInput(
+            params=[complex_param],
+            kwargs={},
+            desc="does not support complex parameters",
+        ),
+        error_type=RuntimeError,
+        error_regex="Adafactor does not support complex parameters",
+        error_on=OptimizerErrorEnum.STEP_ERROR,
+    )
+
+    error_inputs.extend(
+        [negative_eps, invalid_d, invalid_beta2_decay, complex_unsupported]
+    )
+    return error_inputs
+
+
+def optim_inputs_func_adagrad(device, dtype=None):
+    """Return the sample constructor configurations for Adagrad.
+
+    Every entry is an ``OptimizerInput`` with ``params=None``. Neither
+    ``device`` nor ``dtype`` influences the result.
+    """
+    # (constructor kwargs, description) pairs, in the order they are returned.
+    specs = (
+        ({}, "default"),
+        ({"weight_decay": 0.1}, "nonzero weight_decay"),
+        ({"weight_decay": 0.1, "maximize": True}, "maximize"),
+        ({"lr": 0.1}, "non-default lr"),
+        (
+            {"initial_accumulator_value": 0.1, "weight_decay": 0.1},
+            "initial_accumulator_value",
+        ),
+        # TODO: Move out to testing in param_group?
+        ({"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1}, "lr_decay"),
+        ({"lr": torch.tensor(0.001)}, "Tensor lr"),
+    )
+    return [
+        OptimizerInput(params=None, kwargs=kwargs, desc=desc)
+        for kwargs, desc in specs
+    ]
+
+
+def optim_error_inputs_func_adagrad(device, dtype):
+    """Return the shared error cases plus, on CPU, the negative ``lr_decay`` case for Adagrad."""
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) != "cpu":
+        return error_inputs
+
+    negative_lr_decay = ErrorOptimizerInput(
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": 1e-2, "lr_decay": -0.5},
+            desc="lr_decay must be bigger than 0",
+        ),
+        error_type=ValueError,
+        error_regex="Invalid lr_decay value: -0.5",
+    )
+    error_inputs.append(negative_lr_decay)
+    return error_inputs
+
+
+# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
+# with all implementation code paths...
+def optim_inputs_func_adam(device, dtype=None):
+    """Return the sample constructor configurations for Adam.
+
+    The base configurations are always returned. The ``capturable``
+    configurations are appended when the device type is listed in
+    ``CUDA_CONFIG_GPUS``, and the Tensor-lr configuration when it is ``"mps"``.
+    For ``dtype == torch.float16`` every returned configuration additionally
+    gets ``eps=0.1``.
+    """
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
+            desc="capturable, amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
+            desc="Tensor lr with capturable and amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
+                "amsgrad": True,
+                "capturable": True,
+            },
+            desc="Tensor lr, Tensor betas, with capturable and amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "amsgrad": False,
+                "capturable": True,
+            },
+            desc="Tensor lr, Tensor betas, with capturable",
+        ),
+    ]
+    mps_supported_configs = [
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
+        ),
+    ]
+
+    base_specs = (
+        ({}, "default"),
+        ({"lr": 0.01}, "non-default lr"),
+        ({"weight_decay": 0.1}, "nonzero weight_decay"),
+        ({"weight_decay": 0.1, "maximize": True}, "maximize"),
+        ({"weight_decay": 0.1, "amsgrad": True}, "amsgrad"),
+    )
+    configs = [
+        OptimizerInput(params=None, kwargs=kwargs, desc=desc)
+        for kwargs, desc in base_specs
+    ]
+    if _get_device_type(device) in CUDA_CONFIG_GPUS:
+        configs.extend(cuda_supported_configs)
+    if _get_device_type(device) == "mps":
+        configs.extend(mps_supported_configs)
+
+    if dtype == torch.float16:
+        # Too small eps will make denom to be zero for low precision dtype
+        #   denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+        # For example,
+        #   >>> a
+        #   tensor([0.], dtype=torch.float16)
+        #   >>> a + 1e-8
+        #   tensor([0.], dtype=torch.float16)
+        for optim_input in configs:
+            optim_input.kwargs["eps"] = 0.1
+    return configs
+
+
+def optim_error_inputs_func_adam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-1),
+                    desc="weight_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -1",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=torch.tensor(0.001), foreach=True),
+                    desc="lr as Tensor doesn't work with foreach & not capturable",
+                ),
+                error_type=ValueError,
+                error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(0.9, torch.tensor(0.99))),
+                    desc="betas must be either both floats or both Tensors",
+                ),
+                error_type=ValueError,
+                error_regex="betas must be either both floats or both Tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(torch.tensor(0.9), 0.99)),
+                    desc="betas must be either both floats or both Tensors",
+                ),
+                error_type=ValueError,
+                error_regex="betas must be either both floats or both Tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(
+                        lr=1e-2,
+                        betas=(torch.tensor(0.9), torch.tensor(0.99)),
+                        foreach=True,
+                    ),
+                    desc=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
+                ),
+                error_type=ValueError,
+                error_regex=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
+            ),
+        ]
+    if _get_device_type(device) in CUDA_CONFIG_GPUS:
+        sample_tensor = torch.empty((), device=device, dtype=dtype)
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_tensor],
+                    kwargs={"foreach": True, "fused": True},
+                    desc="`fused` and `foreach` cannot be `True` together",
+                ),
+                error_type=RuntimeError,
+                error_regex="`fused` and `foreach` cannot be `True` together",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_tensor],
+                    kwargs={"fused": True, "differentiable": True},
+                    desc="`fused` does not support `differentiable`",
+                ),
+                error_type=RuntimeError,
+                error_regex="`fused` does not support `differentiable`",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adamax(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
+            desc="capturable, maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
+            desc="capturable, maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.9,
+                "maximize": False,
+                "capturable": True,
+            },
+            desc="capturable, weight_decay, tensor LR",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_adamax(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
+                    desc="beta2 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 1: 1.0",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adamw(device, dtype=None):
+    return optim_inputs_func_adam(device, dtype)
+
+
+def optim_error_inputs_func_adamw(device, dtype):
+    return optim_error_inputs_func_adam(device, dtype)
+
+
+def optim_inputs_func_asgd(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True, "capturable": True},
+            desc="maximize, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="maximize, weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.1,
+                "maximize": True,
+                "capturable": True,
+            },
+            desc="maximize, weight_decay, capturable, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
+        OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
+        OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, nonzero weight_decay",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_asgd(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-0.5),
+                    desc="weight_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_lbfgs(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(
+            params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"line_search_fn": "strong_wolfe"},
+            desc="strong_wolfe",
+        ),
+    ]
+
+
+def optim_error_inputs_func_lbfgs(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    return error_inputs
+
+
+def optim_inputs_func_muon(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.2},
+            desc="non-default weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.8},
+            desc="non-default momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"ns_steps": 6},
+            desc="passing alternative ns_steps",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "ns_coefficients": (3.4, -4.7, 2.0),
+            },
+            desc="passing alternative ns_coefficients",
+        ),
+    ]
+
+
+def optim_error_inputs_func_muon(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
+    complex_param.grad = torch.rand_like(complex_param)
+    non_2d_param = torch.rand(2, 3, 4, device=device, dtype=dtype)
+    non_2d_param.grad = torch.rand_like(non_2d_param)
+    param = torch.rand(2, 3, device=device, dtype=dtype)
+    param.grad = torch.rand_like(param)
+    error_inputs += [
+        ErrorOptimizerInput(
+            OptimizerInput(
+                params=[non_2d_param],
+                kwargs=dict(),
+                desc="only support 2D parameters",
+            ),
+            error_type=ValueError,
+            error_regex="Muon only supports 2D parameters",
+            error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        ),
+        ErrorOptimizerInput(
+            OptimizerInput(
+                params=[param],
+                kwargs={"adjust_lr_fn": "arbitrary"},
+                desc="only support `original` and `match_rms_adamw`",
+            ),
+            error_type=ValueError,
+            error_regex="Adjust learning rate function arbitrary is not supported",
+            error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        ),
+        ErrorOptimizerInput(
+            OptimizerInput(
+                params=[complex_param],
+                kwargs=dict(),
+                desc="does not support complex parameters",
+            ),
+            error_type=RuntimeError,
+            error_regex="Muon does not support complex parameters",
+            error_on=OptimizerErrorEnum.STEP_ERROR,
+        ),
+    ]
+    return error_inputs
+
+
+def optim_inputs_func_nadam(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum_decay": 6e-3},
+            desc="non-zero momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+            },
+            desc="weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            desc="weight_decay, momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+            },
+            desc="decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_nadam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum_decay=-0.2),
+                    desc="momentum_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum_decay value: -0.2",
+            ),
+        ]
+    return error_inputs
+
+
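+# NOTE(review): this remark used to read "NAdam and RAdam do not have maximize",
+# yet both of their input functions in this file include a maximize=True
+# config, so the remark looks stale.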
+def optim_inputs_func_radam(device=None, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+            },
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
+        OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
+            desc="decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_radam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-1),
+                    desc="weight_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -1",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_rmsprop(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="capturable, maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+            },
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "centered": True},
+            desc="centered",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+                "weight_decay": 0.1,
+            },
+            desc="maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
+            desc="momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+                "centered": True,
+                "momentum": 0.1,
+                "maximize": True,
+            },
+            desc="maximize, centered, weight_decay, w/ momentum",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_rmsprop(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum=-1.0),
+                    desc="momentum should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum value: -1.0",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_rprop(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"step_sizes": (2e-6, 100)},
+            desc="non-default step_sizes",
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_rprop(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
+                    desc="0 < eta1 < 1 < eta2",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid eta values: 1.0, 0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_sgd(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
+        ),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+        ),
+        OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "dampening": 0.5},
+            desc="dampening",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "weight_decay": 0.1},
+            desc="weight_decay w/ momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
+            desc="nesterov",
+        ),
+    ]
+
+
+def optim_error_inputs_func_sgd(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum=-0.5),
+                    desc="momentum should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_sparseadam(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
+        ),  # TODO: Move out to testing in param_group?
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+    ]
+
+
+def optim_error_inputs_func_sparseadam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        torch.zeros(
+                            3, layout=torch.sparse_coo, device=device, dtype=dtype
+                        )
+                    ],
+                    kwargs={},
+                    desc="dense params required",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam requires dense parameter tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        {
+                            "params": [
+                                torch.zeros(
+                                    3,
+                                    layout=torch.sparse_coo,
+                                    device=device,
+                                    dtype=dtype,
+                                )
+                            ]
+                        }
+                    ],
+                    kwargs={},
+                    desc="dense params required in param_groups",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam requires dense parameter tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
+                    kwargs={},
+                    desc="complex not supported",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam does not support complex parameters",
+            ),
+        ]
+    return error_inputs
+
+
+def _get_device_type(device: Union[str, torch.device]) -> str:
+    # Returns the device type as a string, e.g., "cpu" or "cuda"
+    if isinstance(device, torch.device):
+        device = str(device.type)
+    assert isinstance(device, str)
+    return device.split(":")[0]
+
+
+def _get_optim_inputs_including_global_cliquey_kwargs(
+    device, dtype, optim_info, skip=()
+) -> list[OptimizerInput]:
+    """
+    Return a list of all configs for a given optimizer as a list of OptimizerInputs,
+    including configs that have supported global cliquey kwargs (foreach, fused,
+    differentiable) based on optim_info.supported_impls.
+
+    The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
+    intentionally do NOT include global cliquey kwargs to give flexibility to tests.
+    For example, testing correctness between toggling foreach on and off is now
+    trivial. That said, we sometimes want to test for all possible configs on an
+    optimizer including all supported flags, so this helper returns all optim inputs.
+    """
+    assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
+        "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+    )
+
+    optim_inputs = optim_info.optim_inputs_func(device)
+
+    supported_impls = tuple(
+        x
+        for x in optim_info.supported_impls
+        if x not in skip
+        and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
+        and (
+            _get_device_type(device) in _get_foreach_kernels_supported_devices()
+            or x != "foreach"
+        )
+    )
+
+    all_optim_inputs = []
+    for optim_input in optim_inputs:
+        # Add the base config where all the flags are False
+        base_kwargs = deepcopy(optim_input.kwargs)
+        if len(supported_impls) != 0:
+            for flag in supported_impls:
+                base_kwargs[flag] = False
+            all_optim_inputs.append(
+                OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc)
+            )
+        else:
+            all_optim_inputs.append(optim_input)
+        # Add a config for when each of the global cliquey kwargs is True
+        # Note that in [optimizer kwarg categories], these kwargs are mutually
+        # exclusive, so we do not need to product them together.
+        for flag in supported_impls:
+            new_kwargs = deepcopy(base_kwargs)
+            new_kwargs[flag] = True
+            all_optim_inputs.append(
+                OptimizerInput(
+                    params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
+                )
+            )
+    return all_optim_inputs
+
+
+# Database of OptimizerInfo entries in alphabetical order.
+optim_db: list[OptimizerInfo] = [
+    OptimizerInfo(
+        Adadelta,
+        optim_inputs_func=optim_inputs_func_adadelta,
+        optim_error_inputs_func=optim_error_inputs_func_adadelta,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            # Note on tolerances:
+            # test_correctness_Adadelta_cuda_float32
+            # Mismatched elements: 10 / 100 (10.0%)
+            # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
+            # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
+            # This is due to floating point ordering error + usage of sqrt
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(
+                            rtol=5.5e-4,
+                            atol=5e-5,
+                        )
+                    }
+                ),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adafactor,
+        optim_inputs_func=optim_inputs_func_adafactor,
+        optim_error_inputs_func=optim_error_inputs_func_adafactor,
+        supported_impls=("foreach",),
+        not_og_supported_flags=("foreach",),
+        supports_complex=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip("See #133268 regarding dtype being None"),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                device_type="cuda",
+                active_if=lambda kwargs: kwargs.get("use_closure", False),
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_foreach_large_tensor",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_peak_memory_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_save_load_equality_with_weights_only",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028 regarding copy not supported"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                unittest.skip("See #133268 regarding dtype being None"),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_save_load_equality_with_weights_only",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+                device_type="xpu",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adagrad,
+        optim_inputs_func=optim_inputs_func_adagrad,
+        optim_error_inputs_func=optim_error_inputs_func_adagrad,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu",),
+        supports_sparse=True,
+        metadata_for_sparse=(
+            {"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
+            [
+                lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
+                lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
+            ],
+        ),
+        decorators=(
+            DecorateInfo(
+                #  Note on tolerances:
+                #  the difference comes from the fact that the non-fused kernels have
+                #  more dtype cast operations. We have another test, test_fused_cpu_matches_cuda,
+                #  to make sure there are no discrepancies between the CUDA fused kernel
+                #  and the CPU fused kernel.
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adam,
+        optim_inputs_func=optim_inputs_func_adam,
+        scheduler_inputs=(
+            [lambda opt: ExponentialLR(opt, gamma=0.9)],
+            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
+            [
+                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+            ],
+            [
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        optim_error_inputs_func=optim_error_inputs_func_adam,
+        supported_impls=("foreach", "differentiable", "fused"),
+        has_capturable_arg=True,
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu", "cuda", "xpu", "mps"),
+        decorators=(
+            # Expected floating point error between fused and compiled forloop
+            DecorateInfo(
+                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
+                and kwargs["dtype"] == torch.float64,
+            ),
+            DecorateInfo(
+                #  Note on tolerances:
+                #  the difference comes from the fact that the non-fused kernels have
+                #  more dtype cast operations. We have another test, test_fused_cpu_matches_cuda,
+                #  to make sure there are no discrepancies between the CUDA fused kernel
+                #  and the CPU fused kernel.
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+            DecorateInfo(
+                # Note on tolerances:
+                # Tracking through #127000
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=3e-5, rtol=1.3e-06),
+                    }
+                ),
+                "TestCudaOptims",
+                "test_grad_scaling_autocast_fused_optimizers",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adamax,
+        optim_inputs_func=optim_inputs_func_adamax,
+        optim_error_inputs_func=optim_error_inputs_func_adamax,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                unittest.skip("Uses too much memory, even for H100, surprisingly."),
+                "TestOptimRenewed",
+                "test_foreach_large_tensor",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        AdamW,
+        optim_inputs_func=optim_inputs_func_adamw,
+        optim_error_inputs_func=optim_error_inputs_func_adamw,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu", "cuda", "mps"),
+        has_capturable_arg=True,
+        decorators=(
+            # Expected error between compiled forloop and fused optimizers
+            DecorateInfo(
+                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
+                and kwargs["dtype"] == torch.float64,
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    #  Note on tolerances:
+                    #  the difference comes from the fact that the non-fused kernels have
+                    #  more dtype cast operations. We have another test, test_fused_cpu_matches_cuda,
+                    #  to make sure there are no discrepancies between the CUDA fused kernel
+                    #  and the CPU fused kernel.
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+            # Note on tolerances:
+            # Tracking through #127000
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(
+                            atol=3e-5,
+                            rtol=1.3e-06,
+                        )
+                    }
+                ),
+                "TestCudaOptims",
+                "test_grad_scaling_autocast_fused_optimizers",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        ASGD,
+        optim_inputs_func=optim_inputs_func_asgd,
+        optim_error_inputs_func=optim_error_inputs_func_asgd,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1.5e-5, rtol=1e-5),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+            DecorateInfo(
+                unittest.skip(
+                    "ASGD internally changes the weights even with zero grad"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        LBFGS,
+        optim_inputs_func=optim_inputs_func_lbfgs,
+        optim_error_inputs_func=optim_error_inputs_func_lbfgs,
+        supported_impls=(),
+        step_requires_closure=True,
+        supports_param_groups=False,
+        supports_multiple_devices=False,
+        skips=(
+            # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
+            DecorateInfo(
+                skipIfMPS,
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="mps",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.complex64: tol(
+                            rtol=4.5e-5,
+                            atol=5e-5,
+                        )
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+            ),
+            DecorateInfo(
+                unittest.skip("LBFGS doesn't support multidevice"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction_multigpu",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_group_with_lrscheduler_goes_right_direction",
+            ),
+            # https://github.com/pytorch/pytorch/issues/131398
+            DecorateInfo(
+                unittest.expectedFailure,
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                active_if=lambda kwargs: sys.platform == "darwin"
+                and kwargs["use_closure"],
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Muon,
+        optim_inputs_func=optim_inputs_func_muon,
+        optim_error_inputs_func=optim_error_inputs_func_muon,
+        supported_impls=(),
+        not_og_supported_flags=(),
+        supports_complex=False,
+        skips=(
+            # Note on numerical differences: `compile` applies different matmul tuning,
+            # which leads to deviations compared to eager mode. In the Newton-Schulz
+            # iteration for orthogonalization, computations are done in bfloat16, further
+            # amplifying these numerical differences.
+            DecorateInfo(
+                unittest.skip(
+                    "Expect high difference between compiled and eager due to bfloat16 and iterative process."
+                ),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        NAdam,
+        optim_inputs_func=optim_inputs_func_nadam,
+        optim_error_inputs_func=optim_error_inputs_func_nadam,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors, https://github.com/pytorch/pytorch/issues/117150"
+                ),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        RAdam,
+        optim_inputs_func=optim_inputs_func_radam,
+        optim_error_inputs_func=optim_error_inputs_func_radam,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        # previously atol=1e-7, rtol=1e-7
+                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        RMSprop,
+        optim_inputs_func=optim_inputs_func_rmsprop,
+        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {  # previously atol=5-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
+                        torch.float32: tol(atol=5e-04, rtol=0.01),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+                active_if=TEST_WITH_TORCHDYNAMO,
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Rprop,
+        optim_inputs_func=optim_inputs_func_rprop,
+        optim_error_inputs_func=optim_error_inputs_func_rprop,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        SGD,
+        optim_inputs_func=optim_inputs_func_sgd,
+        scheduler_inputs=(
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
+            [
+                lambda opt: LinearLR(
+                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
+                )
+            ],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: LinearLR(
+                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
+                ),
+            ],
+            [
+                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
+                lambda opt: ExponentialLR(opt, gamma=0.99),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        optim_error_inputs_func=optim_error_inputs_func_sgd,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_sparse=True,
+        metadata_for_sparse=(
+            {
+                "lr": 4.8e-3,
+                "maximize": False,
+                "momentum": 0,
+                "nesterov": False,
+                "weight_decay": 0,
+            },
+            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
+        ),
+        supports_fused_on=(
+            "cpu",
+            "cuda",
+            "xpu",
+            "mps",
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        SparseAdam,
+        optim_inputs_func=optim_inputs_func_sparseadam,
+        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
+        supported_impls=(),
+        only_supports_sparse_grads=True,
+        metadata_for_sparse=({"lr": 4e-2}, []),
+        supports_complex=False,  # Missing complex support, see #118153
+        skips=(
+            DecorateInfo(
+                skipIfMPS,  # SparseAdam does not support MPS
+                "TestOptimRenewed",
+                device_type="mps",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_tensor_lr",
+            ),
+            DecorateInfo(
+                unittest.skip(
+                    "SparseAdam does not support dense gradients, see #116507"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction_multigpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_param_group_with_lrscheduler_goes_right_direction",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_state_dict_with_cuda_params",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+            ),
+        ),
+    ),
+]
+
+
+class TensorTracker:
+    """
+    A utility to track tensor clones in a list, with the expectation of popping them later (in
+    order) to make fair comparisons between two multi-step computation. The intended use case is
+    usually when comparing two supposed equal computations, such as an optimizer step that each
+    individually consists of multiple steps, where numerical deviation could multiply.
+
+    The goal is to be able to compare and align numbers at every milestone so as to minimize
+    numerical discrepancies, and so when the test fails, it is likely a real problem.
+    """
+
+    def __init__(self, assert_eq_kwargs=None):
+        if assert_eq_kwargs is None:
+            assert_eq_kwargs = {}
+        self.assert_eq_kwargs = assert_eq_kwargs
+        self.tensors = []
+
+    def add(self, tensor):
+        """
+        Add a detach().clone()'d version of the tensor
+        """
+        self.tensors.append(tensor.detach().clone())
+
+    # pops from beginning, like a queue and not a stack!
+    def pop_check_set(self, tensor_to_set, testcase):
+        """
+        Pop the first element in the tensor tracker, assert equality between the popped tensor and
+        the input tensor, and then set the input tensor to have the same values as the popped tensor
+        (with copy_).
+        """
+        testcase.assertGreater(len(self.tensors), 0, "no tensors to pop")
+        ref = self.tensors.pop(0)
+
+        testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
+        testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)
+
+        with torch.no_grad():
+            tensor_to_set.copy_(ref)
+
+    def all_popped(self):
+        return len(self.tensors) == 0
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f4fab8c48bbd84c631838b76ff8d7535046a98b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py
@@ -0,0 +1,3415 @@
+# mypy: ignore-errors
+
+r"""Importing this file includes common utility methods and base classes for
+checking quantization api and properties of resulting modules.
+"""
+
+import torch
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from functorch.experimental import control_flow
+from torch.ao.nn.intrinsic import _FusedModule
+from torch.ao.quantization import (
+    convert,
+    default_dynamic_qat_qconfig,
+    default_dynamic_qconfig,
+    default_dynamic_quant_observer,
+    default_embedding_qat_qconfig,
+    default_observer,
+    default_per_channel_qconfig,
+    default_qconfig,
+    default_symmetric_qnnpack_qat_qconfig,
+    default_weight_observer,
+    DeQuantStub,
+    float_qparams_weight_only_qconfig,
+    get_default_qat_qconfig,
+    get_default_qat_qconfig_mapping,
+    get_default_qconfig,
+    get_default_qconfig_mapping,
+    PerChannelMinMaxObserver,
+    propagate_qconfig_,
+    QConfig,
+    QConfigMapping,
+    quantize,
+    quantize_dynamic_jit,
+    quantize_jit,
+    QuantStub,
+    QuantType,
+    QuantWrapper,
+)
+from torch.ao.quantization.backend_config import get_executorch_backend_config
+from torch.ao.quantization.quantization_mappings import (
+    get_default_dynamic_quant_module_mappings,
+    get_default_qat_module_mappings,
+    get_default_qconfig_propagation_list,
+)
+from torch.ao.quantization.quantize_pt2e import (
+    _convert_to_reference_decomposed_fx,
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
+from torch.export import export
+from torch.jit.mobile import _load_for_lite_interpreter
+from torch.testing._internal.common_quantized import override_quantized_engine
+from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase
+
+try:
+    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
+
+    # graph mode quantization based on fx
+    from torch.ao.quantization.quantize_fx import (
+        convert_fx,
+        convert_to_reference_fx,
+        prepare_fx,
+        prepare_qat_fx,
+    )
+    from torch.fx import GraphModule
+    from torch.fx.graph import Node
+
+    HAS_FX = True
+except ImportError:
+    HAS_FX = False
+
+import contextlib
+import copy
+import functools
+import io
+import os
+
+import unittest
+from typing import Any, Optional, Union
+from collections.abc import Callable
+
+import numpy as np
+import torch._dynamo as torchdynamo
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
+from torch.testing import FileCheck
+
+
+class NodeSpec:
+    """Used for checking GraphModule Node"""
+
+    def __init__(self, op, target):
+        """
+        op: call_function | call_module
+        target:
+          for call_function, target would be a function
+          for call_module, target would be the type of PyTorch module
+        """
+        self.op = op
+        self.target = target
+
+    @classmethod
+    def call_function(cls, target):
+        return NodeSpec("call_function", target)
+
+    @classmethod
+    def call_method(cls, target):
+        return NodeSpec("call_method", target)
+
+    @classmethod
+    def call_module(cls, target):
+        return NodeSpec("call_module", target)
+
+    def __hash__(self):
+        return hash((self.op, self.target))
+
+    def __eq__(self, other):
+        if not isinstance(other, NodeSpec):
+            return NotImplemented
+
+        return self.op == other.op and self.target == other.target
+
+    def __repr__(self):
+        return repr(self.op) + " " + repr(self.target)
+
+
+def get_supported_device_types():
+    return (
+        ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"]
+    )
+
+
+def test_only_eval_fn(model, calib_data):
+    r"""
+    Default evaluation function takes a torch.utils.data.Dataset or a list of
+    input Tensors and run the model on the dataset
+    """
+    for inp in calib_data:
+        model(*inp)
+
+
+_default_loss_fn = torch.nn.CrossEntropyLoss()
+
+
+def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
+    r"""
+    Default train function takes a torch.utils.data.Dataset and train the model
+    on the dataset
+    """
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    train_loss, correct, total = 0, 0, 0
+    for _ in range(10):
+        model.train()
+
+        for data, target in train_data:
+            optimizer.zero_grad()
+            output = model(data)
+            loss = loss_fn(output, target)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
+            _, predicted = torch.max(output, 1)
+            total += target.size(0)
+            correct += (predicted == target).sum().item()
+    return train_loss, correct, total
+
+
+class AverageMeter:
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=":f"):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+        return fmtstr.format(**self.__dict__)
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
+    model.train()
+    for cnt, (image, target) in enumerate(data_loader, start=1):
+        print(".", end="")
+        image, target = image.to(device), target.to(device)
+        output = model(image)
+        loss = criterion(output, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        accuracy(output, target, topk=(1, 5))
+        if cnt >= ntrain_batches:
+            return
+    return
+
+
+def ddp_setup(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def ddp_cleanup():
+    """Destroy the process group via ``dist.destroy_process_group()``."""
+    dist.destroy_process_group()
+
+
+def run_ddp(rank, world_size, prepared):
+    """Wrap ``prepared`` in DistributedDataParallel on device ``rank`` and train one batch.
+
+    NOTE(review): ``criterion`` and ``dataset`` are not defined in this function and the
+    call below carries ``noqa: F821``; presumably callers must provide them as module
+    globals, otherwise this raises NameError -- confirm against callers.
+    """
+    ddp_setup(rank, world_size)
+    prepared.cuda()
+    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
+    prepared.to(rank)
+    model_with_ddp = prepared
+    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
+    # Final argument 1 is ntrain_batches: only a single batch is trained.
+    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
+    ddp_cleanup()
+
+
+def convert_dynamic(module):
+    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)
+
+
+def prepare_dynamic(model, qconfig_dict=None):
+    propagate_qconfig_(model, qconfig_dict)
+
+
+def _make_conv_test_input(
+    batch_size,
+    in_channels_per_group,
+    input_feature_map_size,
+    out_channels_per_group,
+    groups,
+    kernel_size,
+    X_scale,
+    X_zero_point,
+    W_scale,
+    W_zero_point,
+    use_bias,
+    use_channelwise,
+):
+    """Build random float and quantized inputs, weights and bias for a quantized conv test.
+
+    ``input_feature_map_size`` and ``kernel_size`` are concatenated onto tuples, so they
+    must be tuples. ``W_scale`` and ``W_zero_point`` are repeated and sliced to
+    ``out_channels`` entries, so they are presumably lists -- confirm against callers.
+
+    Returns:
+        ``(X, X_q, W, W_q, b)`` where ``X_q`` is ``X`` quantized per tensor to quint8,
+        ``W_q`` is ``W`` quantized to qint8 (per channel along axis 0 when
+        ``use_channelwise`` is true, per tensor otherwise), and ``b`` is None when
+        ``use_bias`` is false.
+    """
+    in_channels = in_channels_per_group * groups
+    out_channels = out_channels_per_group * groups
+
+    (X_value_min, X_value_max) = (0, 4)
+    X_init = torch.randint(
+        X_value_min,
+        X_value_max,
+        (
+            batch_size,
+            in_channels,
+        )
+        + input_feature_map_size,
+    )
+    X = X_scale * (X_init - X_zero_point).float()
+    X_q = torch.quantize_per_tensor(
+        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8
+    )
+
+    W_scale = W_scale * out_channels
+    W_zero_point = W_zero_point * out_channels
+    # Resize W_scale and W_zero_points arrays equal to out_channels
+    W_scale = W_scale[:out_channels]
+    W_zero_point = W_zero_point[:out_channels]
+    # For testing, we use small values for weights and for activations so that
+    # no overflow occurs in the vpmaddubsw instruction. If an overflow occurred
+    # in the qconv implementation but not in the reference, the results could
+    # not be matched exactly against the reference.
+    # Please see the comment in qconv implementation file
+    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
+    (W_value_min, W_value_max) = (-5, 5)
+    # The operator expects them in the format
+    # (out_channels, in_channels/groups,) + kernel_size
+    W_init = torch.randint(
+        W_value_min,
+        W_value_max,
+        (
+            out_channels,
+            in_channels_per_group,
+        )
+        + kernel_size,
+    )
+    b_init = torch.randint(0, 10, (out_channels,))
+
+    if use_channelwise:
+        # Shape that broadcasts one scale / zero point per output channel over W.
+        W_shape = (-1, 1) + (1,) * len(kernel_size)
+        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
+        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
+        W = (
+            W_scales_tensor.reshape(*W_shape)
+            * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
+        )
+        b = X_scale * W_scales_tensor * b_init.float()
+        W_q = torch.quantize_per_channel(
+            W,
+            W_scales_tensor.double(),
+            W_zero_points_tensor.long(),
+            0,
+            dtype=torch.qint8,
+        )
+    else:
+        # Per-tensor case: only the first scale / zero point is used.
+        W = W_scale[0] * (W_init - W_zero_point[0]).float()
+        b = X_scale * W_scale[0] * b_init.float()
+        W_q = torch.quantize_per_tensor(
+            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8
+        )
+
+    return (X, X_q, W, W_q, b if use_bias else None)
+
+
+def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
+    (X_value_min, X_value_max) = (0, 4)
+    X_init = torch.randint(
+        X_value_min,
+        X_value_max,
+        sizes,  # Infer the size of tensor to do the add
+    )
+    X = scale * (X_init - zero_point).float()
+    X_q = torch.quantize_per_tensor(
+        X, scale=scale, zero_point=zero_point, dtype=torch.quint8
+    )
+    return X, X_q
+
+
+def skipIfNoFBGEMM(fn):
+    reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer."
+    if isinstance(fn, type):
+        if "fbgemm" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "fbgemm" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoQNNPACK(fn):
+    reason = "Quantized operations require QNNPACK."
+    if isinstance(fn, type):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def withQNNPACKBackend(fn):
+    # TODO(future PR): consider combining with skipIfNoQNNPACK,
+    # will require testing of existing callsites
+    reason = "Quantized operations require QNNPACK."
+    if isinstance(fn, type):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        with override_quantized_engine("qnnpack"):
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoONEDNN(fn):
+    reason = "Quantized operations require ONEDNN."
+    if isinstance(fn, type):
+        if "onednn" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "onednn" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoONEDNNBF16(fn):
+    reason = "Quantized operations require BF16 support."
+    if isinstance(fn, type):
+        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoX86(fn):
+    reason = "Quantized operations require X86."
+    if isinstance(fn, type):
+        if "x86" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "x86" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoDynamoSupport(fn):
+    reason = "dynamo doesn't support."
+    if isinstance(fn, type):
+        if not torchdynamo.is_dynamo_supported():
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torchdynamo.is_dynamo_supported():
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoInductorSupport(fn):
+    reason = "inductor doesn't support."
+    if isinstance(fn, type):
+        if not torchdynamo.is_inductor_supported():
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torchdynamo.is_inductor_supported():
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+# Probe for torchvision once at import time; HAS_TORCHVISION records the outcome and
+# skip_if_no_torchvision is a unittest.skipIf decorator built from it.
+try:
+    import torchvision  # noqa: F401
+
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+
+
+def get_script_module(model, tracing, data):
+    return torch.jit.trace(model, data) if tracing else torch.jit.script(model)
+
+
+def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
+    """
+    Convert lengths to offsets for embedding_bag
+    """
+    tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
+    tt[1:] = t
+    tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
+    if use_begin_offset:
+        return tt[:-1]
+    return tt[1:]
+
+
+def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
+    """Group-wise min/max quantization of a 2-D tensor to ``n_bit`` unsigned values.
+
+    ``w`` is transposed, split into groups of ``q_group_size`` along its last dimension
+    and each group is quantized into ``[0, 2**n_bit - 1]`` from its own min/max.
+
+    Returns:
+        ``(out, scales_and_zeros)``: ``out`` is int32 with the transposed shape on CPU;
+        on any other device adjacent column pairs are packed into one uint8 each.
+        ``scales_and_zeros`` stacks the per-group scale and zero along a trailing
+        dimension of size 2, with the first two dimensions swapped.
+    """
+    assert w.dim() == 2
+    w = w.transpose(0, 1).contiguous()
+    assert q_group_size > 1
+    assert w.shape[-1] % q_group_size == 0
+
+    # One row per quantization group.
+    to_quant = w.reshape(-1, q_group_size)
+    assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    min_int = 0
+    # Clamp the range so constant groups do not yield a zero scale.
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    assert torch.isnan(scales).sum() == 0
+
+    # Float value that corresponds to the mid-point code 2**(n_bit - 1).
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    assert torch.isnan(zeros).sum() == 0
+
+    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+    assert torch.isnan(out).sum() == 0
+
+    out = out.to(dtype=torch.int32).reshape(w.shape)
+    if out.device != torch.device("cpu"):
+        # Pack two values per byte: even columns in the high 4 bits, odd columns in the low bits.
+        out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)
+
+    # Scales and zeros for the same q-group should be contiguous, so we can
+    # load as a 32-bit word
+    scales = scales.view(w.shape[0], -1)
+    zeros = zeros.view(w.shape[0], -1)
+    scales_and_zeros = (
+        torch.cat(
+            [
+                scales.reshape(scales.size(0), scales.size(1), 1),
+                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+            ],
+            2,
+        )
+        .transpose(0, 1)
+        .contiguous()
+    )
+
+    return out, scales_and_zeros
+
+
+def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32):
+    """Group-wise symmetric quantization of a 2-D tensor, packed two values per uint8.
+
+    Returns:
+        ``(out_uint8, scales_and_zeros)``. Despite its name, the second value holds
+        only the per-group scales (squeezed); the zeros computed here are all 0 and
+        are not returned.
+    """
+    # W is of shape [K x N]
+    # We transpose W as Quantization is applied on [N x K]
+    w = w.transpose(0, 1).contiguous()
+    assert w.dim() == 2
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    # Calculate scale and zeros
+    to_quant = w.reshape(-1, groupsize)
+    max_val = to_quant.abs().amax(dim=1, keepdim=True)
+    eps = torch.finfo(max_val.dtype).eps
+    max_int = 2 ** (n_bit - 1) - 1  # For 4-bit, this is 7
+    scales = max_val.clamp(min=eps) / max_int
+    zeros = torch.zeros_like(scales)
+
+    # Quantize the weight
+    scales = scales.to(torch.float32).reshape(w.shape[0], -1)
+    zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    # max_int is redefined here as the largest unsigned code (15 for 4-bit).
+    max_int = 2**n_bit - 1
+    # NOTE(review): the 8.5 offset presumably shifts by 8 and rounds via the
+    # truncating int8 cast, i.e. it is specific to n_bit=4 -- confirm.
+    w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
+    # We pack 2 signed int4 values in unsigned uint8 container.
+    # This reduces the weight size by half and improves load perf
+    out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)
+
+    scales_and_zeros = scales.squeeze().contiguous()
+
+    return out_uint8, scales_and_zeros
+
+
+def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+    """Symmetric quantization of ``x`` with one scale per index along dim 0.
+
+    Min/max are reduced over dim 1 and the scale is broadcast over the last dimension,
+    so ``x`` is presumably 2-D -- confirm against callers.
+
+    Returns:
+        ``(quant, scales, zero_points)``: ``quant`` is clamped to
+        ``[quant_min, quant_max]`` and cast to ``target_dtype``; ``scales`` is cast back
+        to the original dtype of ``x``; ``zero_points`` is an all-zero int64 tensor.
+    """
+    # source: https://github.com/meta-pytorch/gpt-fast/blob/main/quantize.py
+    # default setup for affine quantization of activations
+    x_dtype = x.dtype
+    x = x.float()
+    eps = torch.finfo(torch.float32).eps
+
+    # get min and max
+    min_val, max_val = torch.aminmax(x, dim=1)
+
+    # calculate scales and zero_points based on min and max
+    # reference: https://fburl.com/code/srbiybme
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    device = min_val_neg.device
+
+    # reference: https://fburl.com/code/4wll53rk
+    # Largest absolute value per row drives the symmetric scale.
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scales = max_val_pos / (float(quant_max - quant_min) / 2)
+    # `x` was cast to float above, so scales stay float32 for the arithmetic below;
+    # the cast back to the original dtype happens in the return statement
+    scales = torch.clamp(scales, min=eps).to(x.dtype)
+    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+    # quantize based on qmin/qmax/scales/zp
+    x_div = x / scales.unsqueeze(-1)
+    x_round = torch.round(x_div)
+    x_zp = x_round + zero_points.unsqueeze(-1)
+    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+    return quant, scales.to(x_dtype), zero_points
+
+
+# QuantizationTestCase used as a base class for testing quantization on modules
+class QuantizationTestCase(TestCase):
+    def setUp(self):
+        """Create the random calibration / training fixtures shared by quantization tests.
+
+        Every fixture is a list of two samples. Training fixtures pair an input with a
+        label tensor drawn by ``torch.randint(0, 1, ...)``, whose upper bound is
+        exclusive, so every label is 0.
+        """
+        super().setUp()
+        self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
+        self.train_data = [
+            [
+                torch.rand(2, 5, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)]
+        self.img_data_2d = [
+            [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2)
+        ]
+        self.img_data_3d = [
+            [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2)
+        ]
+        # NOTE(review): the input here has leading size 2 while the label has size 1,
+        # unlike the 2d/3d fixtures below -- confirm this is intended.
+        self.img_data_1d_train = [
+            [
+                torch.rand(2, 3, 10, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_2d_train = [
+            [
+                torch.rand(1, 3, 10, 10, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_3d_train = [
+            [
+                torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+
+        # Lookup of the image fixtures by number of spatial dimensions.
+        self.img_data_dict = {
+            1: self.img_data_1d,
+            2: self.img_data_2d,
+            3: self.img_data_3d,
+        }
+
+        # Quant types that produce statically quantized ops
+        self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
+        # All quant types for (fx based) graph mode quantization
+        self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT]
+
+    def checkNoPrepModules(self, module):
+        r"""Checks the module does not contain child
+        modules for quantization preparation, e.g.
+        quant, dequant and observer
+        """
+        self.assertFalse(hasattr(module, "quant"))
+        self.assertFalse(hasattr(module, "dequant"))
+
+    def checkNoQconfig(self, module):
+        r"""Checks the module does not contain qconfig"""
+        self.assertFalse(hasattr(module, "qconfig"))
+
+        for child in module.children():
+            self.checkNoQconfig(child)
+
+    def checkHasPrepModules(self, module):
+        r"""Checks the module contains child
+        modules for quantization preparation, e.g.
+        quant, dequant and observer
+        """
+        self.assertTrue(hasattr(module, "module"))
+        self.assertTrue(hasattr(module, "quant"))
+        self.assertTrue(hasattr(module, "dequant"))
+
+    def checkObservers(
+        self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None
+    ):
+        r"""Checks the module or module's leaf descendants
+        have observers in preparation for quantization
+
+        Args:
+            module: module (tree) to inspect.
+            propagate_qconfig_list: module types expected to carry an observer;
+                defaults to ``get_default_qconfig_propagation_list()``.
+            prepare_custom_config_dict: optional dict; only its
+                "float_to_observed_custom_module_class" entry is read.
+        """
+        if propagate_qconfig_list is None:
+            propagate_qconfig_list = get_default_qconfig_propagation_list()
+        if prepare_custom_config_dict is None:
+            prepare_custom_config_dict = {}
+        float_to_observed_module_class_mapping = prepare_custom_config_dict.get(
+            "float_to_observed_custom_module_class", {}
+        )
+
+        # check if a module is a leaf module, ignoring activation_post_process attribute
+        def is_leaf_module(module):
+            submodule_name_count = 0
+            for name, _ in module.named_children():
+                if name != "activation_post_process":
+                    submodule_name_count += 1
+            return submodule_name_count == 0
+
+        # An observer is required when the module has a non-None qconfig and is either a
+        # non-Sequential leaf of a propagated type or a registered custom float module;
+        # DeQuantStub is exempt.
+        if (
+            hasattr(module, "qconfig")
+            and module.qconfig is not None
+            and (
+                (
+                    is_leaf_module(module)
+                    and not isinstance(module, torch.nn.Sequential)
+                    and type(module) in propagate_qconfig_list
+                )
+                or type(module) in float_to_observed_module_class_mapping
+            )
+            and not isinstance(module, torch.ao.quantization.DeQuantStub)
+        ):
+            self.assertTrue(
+                hasattr(module, "activation_post_process"),
+                "module: " + str(type(module)) + " do not have observer",
+            )
+        # we don't need to check observers for child modules of the
+        # qat modules
+        if (
+            type(module) not in get_default_qat_module_mappings().values()
+            and type(module) not in float_to_observed_module_class_mapping.values()
+            and not isinstance(module, _FusedModule)
+        ):
+            for child in module.children():
+                # Dropout children are skipped entirely.
+                if type(child) is nn.Dropout:
+                    continue
+                self.checkObservers(
+                    child, propagate_qconfig_list, prepare_custom_config_dict
+                )
+
+    def checkQuantDequant(self, mod):
+        r"""Checks that mod has nn.Quantize and
+        nn.DeQuantize submodules inserted
+        """
+        self.assertEqual(type(mod.quant), nnq.Quantize)
+        self.assertEqual(type(mod.dequant), nnq.DeQuantize)
+
+    def checkWrappedQuantizedLinear(self, mod):
+        r"""Checks that mod.module has been swapped for an nnq.Linear
+        module and that mod has Quantize and DeQuantize
+        submodules. The bias dtype is not inspected here.
+        """
+        self.assertEqual(type(mod.module), nnq.Linear)
+        self.checkQuantDequant(mod)
+
+    def checkQuantizedLinear(self, mod):
+        r"""Checks that mod is exactly an nnq.Linear module"""
+        self.assertEqual(type(mod), nnq.Linear)
+
+    def checkDynamicQuantizedLinear(self, mod, dtype):
+        r"""Checks that mod has been swapped for an nnqd.Linear
+        module whose packed params have the given dtype.
+        """
+        self.assertEqual(type(mod), nnqd.Linear)
+        self.assertEqual(mod._packed_params.dtype, dtype)
+
+    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
+        r"""Checks that mod has been swapped for an nniqd.LinearReLU
+        module whose packed params have the given dtype.
+        """
+        self.assertEqual(type(mod), nniqd.LinearReLU)
+        self.assertEqual(mod._packed_params.dtype, dtype)
+
+    def check_eager_serialization(self, ref_model, loaded_model, x):
+        r"""Checks that ref_model survives both a state_dict round trip (loaded into
+        loaded_model) and a whole-model torch.save / torch.load round trip, by
+        comparing outputs on the inputs ``x`` (unpacked as positional arguments).
+        """
+        # Check state dict serialization and torch.save APIs
+        model_dict = ref_model.state_dict()
+        b = io.BytesIO()
+        torch.save(model_dict, b)
+        b.seek(0)
+        # weights_only=False as we sometimes get a ScriptObject here (weird)
+        loaded_dict = torch.load(b, weights_only=False)
+        loaded_model.load_state_dict(loaded_dict)
+        ref_out = ref_model(*x)
+        load_out = loaded_model(*x)
+
+        # Outputs are indexed as a pair whose second element may itself be a 2-tuple
+        # (presumably an RNN-style (output, hidden) result -- confirm against callers).
+        def check_outputs(ref_out, load_out):
+            self.assertEqual(ref_out[0], load_out[0])
+            if isinstance(ref_out[1], tuple):
+                self.assertEqual(ref_out[1][0], load_out[1][0])
+                self.assertEqual(ref_out[1][1], load_out[1][1])
+            else:
+                self.assertEqual(ref_out[1], load_out[1])
+
+        check_outputs(ref_out, load_out)
+        b = io.BytesIO()
+        torch.save(ref_model, b)
+        b.seek(0)
+        # weights_only=False as this is legacy code that saves the model
+        loaded = torch.load(b, weights_only=False)
+        load_out = loaded(*x)
+        check_outputs(ref_out, load_out)
+
+    def check_weight_bias_api(self, ref_model, weight_keys, bias_keys):
+        weight = ref_model.get_weight()
+        bias = ref_model.get_bias()
+        self.assertEqual(weight_keys ^ weight.keys(), set())
+        self.assertEqual(bias_keys ^ bias.keys(), set())
+
+    def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
+        r"""Checks that mod is exactly of type reference_module_type and that
+        every entry of mod._all_weight_values reports the packed-weight tag
+        matching dtype (torch.qint8 or torch.float16).
+        """
+        # Tag stored in the packed params' state for each supported weight dtype.
+        wt_dtype_map = {
+            torch.qint8: "quantized_dynamic",
+            torch.float16: "quantized_fp16",
+        }
+        self.assertEqual(type(mod), reference_module_type)
+        for packed_params in mod._all_weight_values:
+            self.assertEqual(
+                packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]
+            )
+
+    def checkLinear(self, mod):
+        r"""Checks that mod is exactly a (float) torch.nn.Linear module"""
+        self.assertEqual(type(mod), torch.nn.Linear)
+
+    def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
+        r"""Checks that mod is exactly of type reference_module_type. When mod
+        has an _all_weight_values attribute, also checks that every entry reports
+        the packed-weight tag matching dtype (torch.qint8 or torch.float16).
+        """
+        # Tag stored in the packed params' state for each supported weight dtype.
+        wt_dtype_map = {
+            torch.qint8: "quantized_dynamic",
+            torch.float16: "quantized_fp16",
+        }
+        self.assertEqual(type(mod), reference_module_type)
+        if hasattr(mod, "_all_weight_values"):
+            for packed_params in mod._all_weight_values:
+                self.assertEqual(
+                    packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]
+                )
+
+    def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
+        r"""Checks that orig_mod can be both scripted and traced, and that each
+        TorchScript version matches orig_mod on calib_data (see _checkScriptable).
+        """
+        scripted = torch.jit.script(orig_mod)
+        self._checkScriptable(orig_mod, scripted, calib_data, check_save_load)
+
+        # Use first calib_data entry as trace input
+        traced = torch.jit.trace(orig_mod, calib_data[0])
+        self._checkScriptable(orig_mod, traced, calib_data, check_save_load)
+
+    # Call this twice: once for a scripted module and once for a traced module
+    def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load):
+        r"""Checks script_mod against orig_mod on calib_data, then saves and reloads
+        script_mod through an in-memory buffer. The reloaded module is compared
+        against orig_mod only when check_save_load is true.
+        """
+        self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data)
+
+        # Test save/load
+        buffer = io.BytesIO()
+        torch.jit.save(script_mod, buffer)
+
+        buffer.seek(0)
+        loaded_mod = torch.jit.load(buffer)
+        # Pending __getstate__ and __setstate__ support
+        # See tracking task https://github.com/pytorch/pytorch/issues/23984
+        if check_save_load:
+            self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data)
+
+    def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
+        for inp in calib_data:
+            ref_output = orig_mod(*inp)
+            scripted_output = test_mod(*inp)
+            self.assertEqual(scripted_output, ref_output)
+
+    def checkGraphModeOp(
+        self,
+        module,
+        inputs,
+        quantized_op,
+        tracing=False,
+        debug=False,
+        check=True,
+        eval_mode=True,
+        dynamic=False,
+        qconfig=None,
+    ):
+        """Quantize ``module`` in JIT graph mode and check the resulting graph.
+
+        The module is turned into a script module, then quantized twice: once
+        with ``debug=True`` and once with ``debug=False``.
+
+        Args:
+            module: the module to script/trace and quantize
+            inputs: calibration data; ``inputs[0]`` is used as the example
+                input for scripting/tracing and for running the static model
+            quantized_op: pattern that ``FileCheck`` must find in the
+                non-debug quantized graph
+            tracing: forwarded to ``get_script_module``
+            debug: if True, prints the module, the input graph and both
+                quantized graphs
+            check: if True, asserts that debug and non-debug outputs agree
+                and that ``quantized_op`` appears in the non-debug graph
+            eval_mode: if True, puts ``module`` into eval mode first
+            dynamic: if True, uses ``quantize_dynamic_jit`` instead of
+                ``quantize_jit``
+            qconfig: qconfig used in dynamic mode only; defaults to
+                ``default_dynamic_qconfig``
+
+        Returns:
+            The non-debug quantized model.
+        """
+        if debug:
+            print("Testing:", str(module))
+        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
+
+        if eval_mode:
+            module = module.eval()
+        if dynamic:
+            qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig}
+        model = get_script_module(module, tracing, inputs[0]).eval()
+        if debug:
+            print("input graph:", model.graph)
+        models = {}
+        outputs = {}
+        # NOTE: the loop variable must not be called ``debug``; rebinding the
+        # parameter would leave it False and silence the graph printing below.
+        for debug_mode in (True, False):
+            if dynamic:
+                models[debug_mode] = quantize_dynamic_jit(
+                    model, qconfig_dict, debug=debug_mode
+                )
+                # make sure it runs
+                outputs[debug_mode] = models[debug_mode](inputs)
+            else:
+                # module under test can contain in-place ops, and we depend on
+                # input data staying constant for comparisons
+                inputs_copy = copy.deepcopy(inputs)
+                models[debug_mode] = quantize_jit(
+                    model,
+                    qconfig_dict,
+                    test_only_eval_fn,
+                    [inputs_copy],
+                    inplace=False,
+                    debug=debug_mode,
+                )
+                # make sure it runs
+                outputs[debug_mode] = models[debug_mode](*inputs[0])
+
+        if debug:
+            print("debug graph:", models[True].graph)
+            print("non debug graph:", models[False].graph)
+
+        if check:
+            # debug and non-debug option should have the same numerics
+            self.assertEqual(outputs[True], outputs[False])
+
+            # non debug graph should produce quantized op
+            FileCheck().check(quantized_op).run(models[False].graph)
+
+        return models[False]
+
+    def checkGraphModuleNodes(
+        self,
+        graph_module,
+        expected_node=None,
+        expected_node_occurrence=None,
+        expected_node_list=None,
+    ):
+        """Check that a GraphModule contains the expected nodes.
+
+        Only ``call_function``, ``call_method`` and ``call_module`` nodes are
+        considered; ``call_module`` nodes are identified by the type of the
+        module they call.
+
+        Args:
+            graph_module: the GraphModule instance we want to check
+            expected_node, expected_node_occurrence, expected_node_list:
+               see docs for checkGraphModeFxOp
+        """
+        observed_counts = {}
+        observed_order = []
+        named = dict(graph_module.named_modules(remove_duplicate=False))
+        for node in graph_module.graph.nodes:
+            if node.op in ("call_function", "call_method"):
+                spec = NodeSpec(node.op, node.target)
+            elif node.op == "call_module":
+                spec = NodeSpec(node.op, type(named[node.target]))
+            else:
+                # placeholders, get_attr and output nodes are not tracked
+                continue
+            observed_order.append(spec)
+            observed_counts[spec] = observed_counts.get(spec, 0) + 1
+
+        if expected_node is not None:
+            self.assertTrue(
+                expected_node in observed_counts,
+                f"node:{expected_node!s} not found in the graph module",
+            )
+
+        if expected_node_occurrence is not None:
+            for wanted, occurrence in expected_node_occurrence.items():
+                if occurrence == 0:
+                    self.assertTrue(
+                        wanted not in observed_counts,
+                        f"Check failed for node:{wanted!s}"
+                        " expected no occurrence but found",
+                    )
+                    continue
+                self.assertTrue(
+                    wanted in observed_counts,
+                    f"Check failed for node:{wanted!s} not found",
+                )
+                found = observed_counts[wanted]
+                self.assertTrue(
+                    found == occurrence,
+                    f"Check failed for node:{wanted!s}"
+                    f" Expected occurrence:{occurrence!s}"
+                    f" Found occurrence:{found!s}",
+                )
+
+        if expected_node_list is not None:
+            # Walk the graph once, advancing through the expected list each
+            # time the next expected node is seen (subsequence match).
+            matched = 0
+            wanted_len = len(expected_node_list)
+            for spec in observed_order:
+                if matched == wanted_len:
+                    return
+                if spec == expected_node_list[matched]:
+                    matched += 1
+            self.assertTrue(
+                matched == wanted_len,
+                "Check failed for graph:"
+                + self.printGraphModule(graph_module, print_str=False)
+                + "Expected ordered list:"
+                + str(expected_node_list),
+            )
+
+    def printGraphModule(self, graph_module, print_str=True):
+        """Render one line per graph node and return the text.
+
+        Each line holds the ``repr`` of the node's op, name, target, args and
+        kwargs; ``call_module`` nodes also get the type of the called module.
+        The text is printed as well unless ``print_str`` is False.
+        """
+        named = dict(graph_module.named_modules(remove_duplicate=False))
+        lines = []
+        for node in graph_module.graph.nodes:
+            fields = (node.op, node.name, node.target, node.args, node.kwargs)
+            line = " ".join(repr(field) for field in fields)
+            if node.op == "call_module":
+                line += " module type: " + repr(type(named[node.target]))
+            lines.append(line)
+        rendered = "\n".join(lines)
+        if print_str:
+            print(rendered)
+        return rendered
+
+    if HAS_FX:
+
+        def assert_types_for_matched_subgraph_pairs(
+            self,
+            matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]],
+            expected_types: dict[
+                str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]
+            ],
+            gm_a: GraphModule,
+            gm_b: GraphModule,
+        ) -> None:
+            """
+            Verifies that the types specified in expected_types match
+            the underlying objects pointed to by the nodes in matched_subgraph_pairs.
+
+            For every key, expected_types holds
+            ``((start_type_a, end_type_a), (start_type_b, end_type_b))`` and
+            matched_subgraph_pairs holds ``(subgraph_a, subgraph_b)``, e.g.:
+
+              matched_subgraph_pairs = {'x0': (subgraph_a_conv_0, subgraph_b_conv_0)}
+              expected_types = {'x0': ((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d))}
+
+            The function checks that both dicts have the same number of
+            entries and compares types by identity (``is``).
+            """
+
+            def _op_type_of(node: Node, gm: GraphModule) -> Union[Callable, str]:
+                # call_module nodes resolve to the module's class; function and
+                # method calls are identified by their target directly.
+                if node.op != "call_module":
+                    assert node.op in ("call_function", "call_method")
+                    return node.target
+                return type(getattr(gm, node.target))
+
+            self.assertTrue(
+                len(matched_subgraph_pairs) == len(expected_types),
+                f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}",
+            )
+            for k, (
+                (exp_start_a, exp_end_a),
+                (exp_start_b, exp_end_b),
+            ) in expected_types.items():
+                subgraph_a, subgraph_b = matched_subgraph_pairs[k]
+
+                act_start_a = _op_type_of(subgraph_a.start_node, gm_a)
+                act_start_b = _op_type_of(subgraph_b.start_node, gm_b)
+                act_end_a = _op_type_of(subgraph_a.end_node, gm_a)
+                act_end_b = _op_type_of(subgraph_b.end_node, gm_b)
+
+                expected_quad = (exp_start_a, exp_end_a, exp_start_b, exp_end_b)
+                actual_quad = (act_start_a, act_end_a, act_start_b, act_end_b)
+                types_match = all(
+                    exp is act for exp, act in zip(expected_quad, actual_quad)
+                )
+                self.assertTrue(
+                    types_match,
+                    f"Type mismatch at {k}: expected {expected_quad}, "
+                    f"got {actual_quad}",
+                )
+
+        def assert_ns_compare_dict_valid(
+            self,
+            act_compare_dict: dict[str, dict[str, dict[str, Any]]],
+        ) -> None:
+            """
+            Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
+            1. for each layer, results are recorded for two models
+            2. the result types recorded for the two models match
+            3. number of seen tensors match
+            4. shapes of each pair of seen tensors match
+            5. ref_node_name equals prev_node_name for node outputs and
+               differs from it for node inputs
+
+            The dict is indexed as
+            ``act_compare_dict[layer_name][result_type][model_name][res_idx]``.
+            """
+            for layer_name, result_type_to_data in act_compare_dict.items():
+                for result_type, layer_data in result_type_to_data.items():
+                    self.assertTrue(
+                        len(layer_data) == 2,
+                        f"Layer {layer_name} does not have exactly two model results.",
+                    )
+                    model_name_0, model_name_1 = layer_data.keys()
+                    for res_idx in range(len(layer_data[model_name_0])):
+                        layer_data_0 = layer_data[model_name_0][res_idx]
+                        layer_data_1 = layer_data[model_name_1][res_idx]
+                        # Compare model 0 against model 1 (the previous code
+                        # compared model 0 with itself, which always passed).
+                        self.assertTrue(
+                            layer_data_0["type"] == layer_data_1["type"],
+                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.",
+                        )
+
+                        self.assertTrue(
+                            len(layer_data_0["values"]) == len(layer_data_1["values"]),
+                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.",
+                        )
+
+                        # F.conv1d weight has rank 3, and toq.conv1d unpacked weight
+                        # has rank 4. For now, skip the length check for conv1d only.
+                        is_weight_functional_conv1d = (
+                            result_type == NSSingleResultValuesType.WEIGHT.value
+                            and (
+                                "conv1d" in layer_data_0["prev_node_target_type"]
+                                or "conv1d" in layer_data_1["prev_node_target_type"]
+                            )
+                        )
+                        if not is_weight_functional_conv1d:
+                            for idx in range(len(layer_data_0["values"])):
+                                values_0 = layer_data_0["values"][idx]
+                                values_1 = layer_data_1["values"][idx]
+                                if isinstance(values_0, torch.Tensor):
+                                    self.assertTrue(
+                                        values_0.shape == values_1.shape,
+                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} "
+                                        + f"have a shape mismatch at idx {idx}.",
+                                    )
+                                elif isinstance(values_0, list):
+                                    # only the first element of a list is compared
+                                    values_0 = values_0[0]
+                                    values_1 = values_1[0]
+                                    self.assertTrue(
+                                        values_0.shape == values_1.shape,
+                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} "
+                                        + f"have a shape mismatch at idx {idx}.",
+                                    )
+                                else:
+                                    # remaining supported layout: (tensor, (tensor, tensor))
+                                    assert isinstance(
+                                        values_0, tuple
+                                    ), f"unhandled type {type(values_0)}"
+                                    assert len(values_0) == 2
+                                    assert len(values_0[1]) == 2
+                                    assert values_0[0].shape == values_1[0].shape
+                                    assert values_0[1][0].shape == values_1[1][0].shape
+                                    assert values_0[1][1].shape == values_1[1][1].shape
+
+                        # verify that ref_node_name is valid
+                        ref_node_name_0 = layer_data_0["ref_node_name"]
+                        ref_node_name_1 = layer_data_1["ref_node_name"]
+                        prev_node_name_0 = layer_data_0["prev_node_name"]
+                        prev_node_name_1 = layer_data_1["prev_node_name"]
+                        if (
+                            layer_data_0["type"]
+                            == NSSingleResultValuesType.NODE_OUTPUT.value
+                        ):
+                            self.assertTrue(ref_node_name_0 == prev_node_name_0)
+                            self.assertTrue(ref_node_name_1 == prev_node_name_1)
+                        elif (
+                            layer_data_0["type"]
+                            == NSSingleResultValuesType.NODE_INPUT.value
+                        ):
+                            self.assertTrue(ref_node_name_0 != prev_node_name_0)
+                            self.assertTrue(ref_node_name_1 != prev_node_name_1)
+
+        def checkGraphModeFxOp(
+            self,
+            model,
+            inputs,
+            quant_type,
+            expected_node=None,
+            expected_node_occurrence=None,
+            expected_node_list=None,
+            is_reference=False,
+            print_debug_info=False,
+            custom_qconfig_dict=None,
+            prepare_expected_node=None,
+            prepare_expected_node_occurrence=None,
+            prepare_expected_node_list=None,
+            prepare_custom_config=None,
+            backend_config=None,
+        ):
+            """Quantizes model with graph mode quantization on fx and check if the
+            quantized model contains the quantized_node
+
+            Args:
+                model: floating point torch.nn.Module
+                inputs: positional sample input arguments for model; if a
+                    list is given, only its first element is used
+                quant_type: QuantType.QAT prepares with prepare_qat_fx in
+                    train mode, QuantType.STATIC prepares with prepare_fx in
+                    eval mode, any other value uses default_dynamic_qconfig
+                    in eval mode
+                expected_node: NodeSpec
+                    e.g. NodeSpec.call_function(torch.quantize_per_tensor)
+                expected_node_occurrence: a dict from NodeSpec to
+                    expected number of occurrences (int)
+                    e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
+                            NodeSpec.call_method('dequantize'): 1}
+                expected_node_list: a list of NodeSpec, used to check the order
+                    of the occurrence of Node
+                    e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
+                            NodeSpec.call_module(nnq.Conv2d),
+                            NodeSpec.call_function(F.hardtanh_),
+                            NodeSpec.call_method('dequantize')]
+                is_reference: if True, the node checks (and debug printing)
+                    are run on the reference quantized model instead of the
+                    non-reference one
+                print_debug_info: if True, prints debug info
+                custom_qconfig_dict: a QConfigMapping or dict that replaces
+                    the default qconfig mapping
+                prepare_expected_node: same as expected_node, but for prepare
+                prepare_expected_node_occurrence: same as
+                    expected_node_occurrence, but for prepare
+                prepare_expected_node_list: same as expected_node_list, but
+                    for prepare
+                prepare_custom_config: forwarded to the prepare function
+                backend_config: forwarded to the prepare function
+
+            Returns:
+                A dictionary with the following structure:
+               {
+                   "prepared": ...,  # copy of the prepared model
+                   "quantized": ...,  # copy of the quantized non-reference model
+                   "quantized_reference": ...,  # copy of the quantized reference model
+                   "quantized_output": ...,  # output of the non-reference model
+                   "quantized_reference_output": ...,  # output of the reference model
+               }
+            """
+            # TODO: make img_data a single example instead of a list
+            if type(inputs) is list:
+                inputs = inputs[0]
+
+            # Pick the default qconfig mapping and train/eval mode per quant type.
+            if quant_type == QuantType.QAT:
+                qconfig_mapping = get_default_qat_qconfig_mapping(
+                    torch.backends.quantized.engine
+                )
+                model.train()
+            elif quant_type == QuantType.STATIC:
+                qconfig_mapping = get_default_qconfig_mapping(
+                    torch.backends.quantized.engine
+                )
+                model.eval()
+            else:
+                qconfig = default_dynamic_qconfig
+                qconfig_mapping = QConfigMapping().set_global(qconfig)
+                model.eval()
+
+            if quant_type == QuantType.QAT:
+                prepare = prepare_qat_fx
+            else:
+                prepare = prepare_fx
+
+            # overwrite qconfig_dict with custom_qconfig_dict
+            if custom_qconfig_dict is not None:
+                assert type(custom_qconfig_dict) in (
+                    QConfigMapping,
+                    dict,
+                ), "custom_qconfig_dict should be a QConfigMapping or a dict"
+                if isinstance(custom_qconfig_dict, QConfigMapping):
+                    qconfig_mapping = custom_qconfig_dict
+                else:
+                    qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict)
+            prepared = prepare(
+                model,
+                qconfig_mapping,
+                example_inputs=inputs,
+                prepare_custom_config=prepare_custom_config,
+                backend_config=backend_config,
+            )
+            # Run the prepared model once on the sample inputs (skipped for
+            # dynamic quantization).
+            if quant_type != QuantType.DYNAMIC:
+                prepared(*inputs)
+
+            if print_debug_info:
+                print()
+                print("quant type:\n", quant_type)
+                print("original model:\n", model)
+                print()
+                print("prepared model:\n", prepared)
+
+            self.checkGraphModuleNodes(
+                prepared,
+                prepare_expected_node,
+                prepare_expected_node_occurrence,
+                prepare_expected_node_list,
+            )
+
+            # Both conversions work on their own deep copy of the prepared model.
+            prepared_copy = copy.deepcopy(prepared)
+            qgraph = convert_fx(copy.deepcopy(prepared))
+            qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared))
+            result = qgraph(*inputs)
+            result_reference = qgraph_reference(*inputs)
+            qgraph_copy = copy.deepcopy(qgraph)
+            qgraph_reference_copy = copy.deepcopy(qgraph_reference)
+
+            qgraph_to_check = qgraph_reference if is_reference else qgraph
+            if print_debug_info:
+                print()
+                print("quantized model:\n", qgraph_to_check)
+                self.printGraphModule(qgraph_to_check)
+                print()
+            self.checkGraphModuleNodes(
+                qgraph_to_check,
+                expected_node,
+                expected_node_occurrence,
+                expected_node_list,
+            )
+            return {
+                "prepared": prepared_copy,
+                "quantized": qgraph_copy,
+                "quantized_reference": qgraph_reference_copy,
+                "quantized_output": result,
+                "quantized_reference_output": result_reference,
+            }
+
+    def checkEmbeddingSerialization(
+        self,
+        qemb,
+        num_embeddings,
+        embedding_dim,
+        indices,
+        offsets,
+        set_qconfig,
+        is_emb_bag,
+        dtype=torch.quint8,
+    ):
+        """Exercise serialization paths of a quantized Embedding / EmbeddingBag.
+
+        Covers, in order: a ``torch.save``/``torch.load`` round trip of the
+        state dict, ``check_eager_serialization`` against a freshly built
+        module, ``load_state_dict``, JIT save/load via ``checkScriptable``,
+        and ``from_float`` conversion of a float module.
+
+        Args:
+            qemb: the quantized embedding module under test
+            num_embeddings, embedding_dim: sizes used to build the fresh
+                quantized and float modules
+            indices: index input passed to the modules
+            offsets: offsets input; only used when ``is_emb_bag`` is True
+            set_qconfig: if True, attaches a float-qparams per-channel
+                qconfig to the float module before ``prepare_dynamic``
+            is_emb_bag: selects EmbeddingBag (True) or Embedding (False)
+            dtype: dtype for the fresh quantized module and the weight observer
+        """
+        # Test serialization of dynamic EmbeddingBag module using state_dict
+        if is_emb_bag:
+            inputs = [indices, offsets]
+        else:
+            inputs = [indices]
+        emb_dict = qemb.state_dict()
+        b = io.BytesIO()
+        torch.save(emb_dict, b)
+        b.seek(0)
+        loaded_dict = torch.load(b)
+        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
+        # Check unpacked weight values explicitly
+        for key in emb_dict:
+            if isinstance(emb_dict[key], torch._C.ScriptObject):
+                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
+                emb_weight = embedding_unpack(emb_dict[key])
+                loaded_weight = embedding_unpack(loaded_dict[key])
+                self.assertEqual(emb_weight, loaded_weight)
+
+        # Check state dict serialization and torch.save APIs
+        if is_emb_bag:
+            loaded_qemb = nnq.EmbeddingBag(
+                num_embeddings=num_embeddings,
+                embedding_dim=embedding_dim,
+                include_last_offset=True,
+                mode="sum",
+                dtype=dtype,
+            )
+        else:
+            loaded_qemb = nnq.Embedding(
+                num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype
+            )
+        self.check_eager_serialization(qemb, loaded_qemb, inputs)
+
+        # Loading the round-tripped state dict must reproduce the packed weight.
+        loaded_qemb.load_state_dict(loaded_dict)
+        self.assertEqual(
+            embedding_unpack(qemb._packed_params._packed_weight),
+            embedding_unpack(loaded_qemb._packed_params._packed_weight),
+        )
+
+        # Test JIT serialization
+        self.checkScriptable(qemb, [inputs], check_save_load=True)
+
+        # Test from_float call
+        if is_emb_bag:
+            float_embedding = torch.nn.EmbeddingBag(
+                num_embeddings=num_embeddings,
+                embedding_dim=embedding_dim,
+                include_last_offset=True,
+                scale_grad_by_freq=False,
+                mode="sum",
+            )
+        else:
+            float_embedding = torch.nn.Embedding(
+                num_embeddings=num_embeddings, embedding_dim=embedding_dim
+            )
+
+        if set_qconfig:
+            float_qparams_observer = PerChannelMinMaxObserver.with_args(
+                dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
+            )
+            float_embedding.qconfig = QConfig(
+                activation=default_dynamic_quant_observer, weight=float_qparams_observer
+            )
+
+        prepare_dynamic(float_embedding)
+
+        # Run the float module once before converting it.
+        float_embedding(*inputs)
+        if is_emb_bag:
+            q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding)
+            expected_name = "QuantizedEmbeddingBag"
+        else:
+            q_embeddingbag = nnq.Embedding.from_float(float_embedding)
+            expected_name = "QuantizedEmbedding"
+
+        # make sure the converted module runs
+        q_embeddingbag(*inputs)
+
+        self.assertTrue(expected_name in str(q_embeddingbag))
+
+
+class QuantizationLiteTestCase(QuantizationTestCase):
+    """Helpers for comparing scripted modules with lite-interpreter modules."""
+
+    def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs):
+        """Instantiate ``model_class(**kwargs)`` and quantize it under qnnpack.
+
+        Calibration uses ``test_only_eval_fn`` on ``self.calib_data``.
+        """
+        # Creates quantized model for testing mobile script modules
+        engine = "qnnpack"
+        with override_quantized_engine(engine):
+            # FIXME(rec): shouldn't qconfig be passed to quantize?
+            qconfig = torch.ao.quantization.get_default_qconfig(engine)  # noqa: F841
+            quantized = quantize(
+                model_class(**kwargs), test_only_eval_fn, [self.calib_data]
+            )
+
+        return quantized
+
+    def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor):
+        """Check that the lite-interpreter module matches the scripted module.
+
+        The mobile module is invoked three ways (call, ``forward`` and
+        ``run_method``) and each result is compared with the scripted result.
+        The whole comparison is attempted up to five times; the last
+        ``AssertionError`` is propagated.
+        """
+        # Compares the numerical outputs for script and lite modules
+        engine = "qnnpack"
+        with override_quantized_engine(engine):
+            script_module = torch.jit.script(model)
+            expected = script_module(input)
+
+            max_retry = 5
+            for attempt in range(1, max_retry + 1):
+                # retries `max_retry` times; stops on success else re-raises
+                try:
+                    serialized = script_module._save_to_buffer_for_lite_interpreter()
+                    buffer = io.BytesIO(serialized)
+                    buffer.seek(0)
+                    mobile_module = _load_for_lite_interpreter(buffer)
+
+                    called_result = mobile_module(input)
+                    torch.testing.assert_close(expected, called_result)
+
+                    forward_result = mobile_module.forward(input)
+                    torch.testing.assert_close(expected, forward_result)
+
+                    run_method_result = mobile_module.run_method("forward", input)
+                    torch.testing.assert_close(expected, run_method_result)
+                except AssertionError:
+                    if attempt == max_retry:
+                        raise
+                else:
+                    break
+
+
+class PT2EQuantizationTestCase(QuantizationTestCase):
+    """
+    Base QuantizationTestCase for PT2 with some helper methods.
+    """
+
+    # Maps the op overload packets used in expectations to the overloads
+    # looked up in the FX reference graph.
+    _MAP_TO_FX_TRACED_OPS = {
+        torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    }
+
+    def _test_quantizer(
+        self,
+        model,
+        example_inputs,
+        quantizer,
+        expected_node_occurrence,
+        expected_node_list=None,
+        check_against_fx_quant=False,
+        fx_qconfig_mapping=None,
+        export_with_dynamic_shape=False,
+        is_qat=False,
+        is_debug_mode=False,
+        training_ir_node_occurrence=None,
+    ):
+        """Export, prepare, calibrate and convert ``model``, then check its nodes.
+
+        ``expected_node_occurrence`` and ``expected_node_list`` hold call_function
+        targets and are checked against the converted model. When
+        ``check_against_fx_quant`` is set, the eager model is also quantized
+        through the FX reference flow with ``fx_qconfig_mapping`` and its
+        output must equal the PT2E output; ``training_ir_node_occurrence``,
+        if given, replaces the node expectations for that FX model.
+
+        Returns the converted PT2E model.
+        """
+        # resetting dynamo cache
+        torch._dynamo.reset()
+        m_eager = model.eval()
+
+        # program capture
+        dynamic_shapes = tuple(
+            {0: torch.export.Dim("dim")} if i == 0 else None
+            for i in range(len(example_inputs))
+        )
+        export_shapes = dynamic_shapes if export_with_dynamic_shape else None
+        m = export(
+            copy.deepcopy(m_eager),
+            example_inputs,
+            dynamic_shapes=export_shapes,
+            strict=True,
+        ).module()
+
+        prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e
+        m = prepare_fn(m, quantizer)
+        if is_debug_mode:
+            print("prepared model:", m)
+        # Calibrate
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        if is_debug_mode:
+            print("quantized model", m)
+
+        pt2_quant_output = m(*example_inputs)
+        occurrence_specs = {
+            NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        if expected_node_list is None:
+            expected_node_list = []
+        ordered_specs = [NodeSpec.call_function(n) for n in expected_node_list]
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_occurrence=occurrence_specs,
+            expected_node_list=ordered_specs,
+        )
+        if not check_against_fx_quant:
+            return m
+
+        # Quantize the eager model through the FX reference flow and compare.
+        backend_config = get_executorch_backend_config()
+        m_fx = prepare_fx(
+            copy.deepcopy(m_eager),
+            fx_qconfig_mapping,
+            example_inputs,
+            backend_config=backend_config,
+        )
+        m_fx(*example_inputs)
+        m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
+        m_fx = export(
+            m_fx,
+            example_inputs,
+            dynamic_shapes=export_shapes,
+            strict=True,
+        ).module()
+        if training_ir_node_occurrence is not None:
+            fx_occurrence = {
+                NodeSpec.call_function(k): v
+                for k, v in training_ir_node_occurrence.items()
+            }
+        else:
+            fx_occurrence = {
+                NodeSpec.call_function(v): expected_node_occurrence[k]
+                for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items()
+                if k in expected_node_occurrence
+            }
+        self.checkGraphModuleNodes(m_fx, expected_node_occurrence=fx_occurrence)
+        fx_quant_output = m_fx(*example_inputs)
+        self.assertEqual(fx_quant_output, pt2_quant_output)
+        return m
+
+    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
+        """Export ``m``, prepare it with ``quantizer``, calibrate once and convert."""
+        # resetting dynamo cache
+        torch._dynamo.reset()
+
+        exported = export(m, example_inputs, strict=True).module()
+        prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e
+        prepared = prepare_fn(exported, quantizer)
+        prepared(*example_inputs)
+        return convert_pt2e(prepared)
+
+    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
+        """Return a PT2E-quantized single ``Linear(2, 2)`` model.
+
+        Uses XNNPACKQuantizer with a global symmetric quantization config.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        quantizer.set_global(
+            get_symmetric_quantization_config(is_per_channel=is_per_channel)
+        )
+        example_inputs = (torch.randn(2, 2),)
+        return self._quantize(M().eval(), quantizer, example_inputs)
+
+
+# Below are a series of toy models to use in testing quantization
+
+
+class SingleLayerLinearModel(torch.nn.Module):
+    """Toy model: a single float ``Linear(5, 5)`` layer."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        return self.fc1(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return one random ``(1, 5)`` input as a 1-tuple."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class AnnotatedSingleLayerLinearModel(torch.nn.Module):
+    """Toy model: a ``QuantWrapper``-wrapped ``Linear(5, 5)`` with a qconfig."""
+
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        linear = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.fc1 = QuantWrapper(linear)
+
+    def forward(self, x):
+        return self.fc1(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return one random ``(1, 5)`` input as a 1-tuple."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class SingleLayerLinearDynamicModel(torch.nn.Module):
+    """Toy model: a plain ``Linear(5, 5)`` carrying a default qconfig."""
+
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        return self.fc1(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return one random ``(1, 5)`` input as a 1-tuple."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class LinearAddModel(nn.Module):
+    """Toy model: ``Linear(5, 8)`` -> add 5 -> ``Linear(8, 5)``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        hidden = self.fc1(x)
+        shifted = torch.add(hidden, 5)
+        return self.fc2(shifted)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return one random ``(1, 5)`` input as a 1-tuple."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class RNNDynamicModel(torch.nn.Module):
+    """Toy model wrapping a ``GRU(2, 2)`` or ``LSTM(2, 2)`` with the dynamic qconfig.
+
+    ``mod_type`` selects the layer ("GRU" or "LSTM"); any other value leaves
+    ``self.mod`` unset.
+    """
+
+    def __init__(self, mod_type):
+        super().__init__()
+        self.qconfig = default_dynamic_qconfig
+        if mod_type == "GRU":
+            self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
+        elif mod_type == "LSTM":
+            self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)
+
+    def forward(self, x):
+        return self.mod(x)
+
+
+class RNNCellDynamicModel(torch.nn.Module):
+    """Toy model wrapping a single recurrent cell with the dynamic qconfig.
+
+    ``mod_type`` is one of "GRUCell", "LSTMCell", "RNNReLU" or "RNNTanh";
+    any other value leaves ``self.mod`` unset.
+    """
+
+    def __init__(self, mod_type):
+        super().__init__()
+        self.qconfig = default_dynamic_qconfig
+        if mod_type == "GRUCell":
+            self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float)
+        elif mod_type == "LSTMCell":
+            self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float)
+        elif mod_type == "RNNReLU":
+            self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float)
+        elif mod_type == "RNNTanh":
+            self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float)
+
+    def forward(self, x):
+        return self.mod(x)
+
+
+class LSTMwithHiddenDynamicModel(torch.nn.Module):
+    """``LSTM(2, 2)`` whose forward takes and returns the hidden state explicitly."""
+
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        lookup_qconfig = torch.ao.quantization.get_default_qconfig
+        self.qconfig = lookup_qconfig(qengine)
+        self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)
+
+    def forward(self, x, hid):
+        """Run the LSTM on ``x`` with state ``hid``; return ``(output, new_state)``."""
+        output, new_state = self.lstm(x, hid)
+        return output, new_state
+
+
+class ConvModel(torch.nn.Module):
+    """Single bias-free ``Conv2d(3, 5, 3)`` layer."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        conv = torch.nn.Conv2d(3, 5, 3, bias=False)
+        self.conv = conv.to(dtype=torch.float)
+
+    def forward(self, x):
+        """Apply ``conv`` to ``x``."""
+        return self.conv(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class ConvTransposeModel(torch.nn.Module):
+    """Single bias-free ``ConvTranspose2d(3, 5, 3)`` layer."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        deconv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False)
+        self.conv = deconv.to(dtype=torch.float)
+
+    def forward(self, x):
+        """Apply the transposed convolution to ``x``."""
+        return self.conv(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class AnnotatedConvModel(torch.nn.Module):
+    """Bias-free ``Conv2d(3, 5, 3)`` placed between a QuantStub and a DeQuantStub."""
+
+    def __init__(self, qengine):
+        super().__init__()
+        lookup_qconfig = torch.ao.quantization.get_default_qconfig
+        self.qconfig = lookup_qconfig(qengine)
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        """Return ``dequant(conv(quant(x)))``."""
+        quantized = self.quant(x)
+        return self.dequant(self.conv(quantized))
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class AnnotatedConvTransposeModel(torch.nn.Module):
+    """Bias-free ``ConvTranspose2d(3, 5, 3)`` between a QuantStub and a DeQuantStub."""
+
+    def __init__(self, qengine):
+        super().__init__()
+        lookup_qconfig = torch.ao.quantization.get_default_qconfig
+        self.qconfig = lookup_qconfig(qengine)
+        deconv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False)
+        self.conv = deconv.to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        """Return ``dequant(conv(quant(x)))``."""
+        quantized = self.quant(x)
+        return self.dequant(self.conv(quantized))
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class ConvBnModel(torch.nn.Module):
+    """Bias-free ``Conv2d(3, 5, 3)`` followed by ``BatchNorm2d(5)``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``bn(conv(x))``."""
+        convolved = self.conv(x)
+        return self.bn(convolved)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class AnnotatedConvBnModel(torch.nn.Module):
+    """Conv2d + BatchNorm2d between quant/dequant stubs, using ``default_qconfig``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.qconfig = default_qconfig
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        """Return ``dequant(bn(conv(quant(x))))``."""
+        convolved = self.conv(self.quant(x))
+        return self.dequant(self.bn(convolved))
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class ConvBnReLUModel(torch.nn.Module):
+    """Bias-free Conv2d, BatchNorm2d and an in-place ReLU applied in sequence."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Return ``relu(bn(conv(x)))``."""
+        normalized = self.bn(self.conv(x))
+        return self.relu(normalized)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class AnnotatedConvBnReLUModel(torch.nn.Module):
+    """Conv2d + BatchNorm2d + in-place ReLU between quant/dequant stubs."""
+
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        lookup_qconfig = torch.ao.quantization.get_default_qconfig
+        self.qconfig = lookup_qconfig(qengine)
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=True)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        """Return ``dequant(relu(bn(conv(quant(x)))))``."""
+        normalized = self.bn(self.conv(self.quant(x)))
+        return self.dequant(self.relu(normalized))
+
+    def fuse_model(self):
+        """Fuse ``conv``/``bn``/``relu`` in place, using the QAT variant in training mode."""
+        # TODO: remove this check and define two fuse_modules function on this module
+        fuse = (
+            torch.ao.quantization.fuse_modules_qat
+            if self.training
+            else torch.ao.quantization.fuse_modules
+        )
+        fuse(self, [["conv", "bn", "relu"]], inplace=True)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class TwoLayerConvModel(torch.nn.Module):
+    """Two stacked bias-free convolutions: ``Conv2d(3, 5, 3)`` then ``Conv2d(5, 5, 1)``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``conv2(conv1(x))``."""
+        first = self.conv1(x)
+        return self.conv2(first)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class TwoLayerLinearModel(torch.nn.Module):
+    """Two stacked linear layers (5 -> 8 -> 5)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc2(fc1(x))``."""
+        hidden = self.fc1(x)
+        return self.fc2(hidden)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class LinearModelWithSubmodule(nn.Module):
+    """A ``TwoLayerLinearModel`` submodule followed by a ``Linear(5, 5)``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.subm = TwoLayerLinearModel()
+        self.fc = nn.Linear(5, 5)
+
+    def forward(self, x):
+        """Return ``fc(subm(x))``."""
+        sub_out = self.subm(x)
+        return self.fc(sub_out)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to the submodule's example inputs."""
+        return self.subm.get_example_inputs()
+
+
+class AnnotatedTwoLayerLinearModel(torch.nn.Module):
+    """Two linear layers where only ``fc2`` is wrapped and given a qconfig."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        second = torch.nn.Linear(8, 5).to(dtype=torch.float)
+        self.fc2 = QuantWrapper(second)
+        self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+
+    def forward(self, x):
+        """Return ``fc2(fc1(x))``."""
+        hidden = self.fc1(x)
+        return self.fc2(hidden)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class ActivationsTestModel(torch.nn.Module):
+    """Hardswish followed by ELU, between a QuantStub and a DeQuantStub."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        quantization = torch.ao.quantization
+        self.qconfig = quantization.get_default_qconfig("fbgemm")
+        self.quant = quantization.QuantStub()
+        self.hardswish = torch.nn.Hardswish().to(dtype=torch.float)
+        self.elu = torch.nn.ELU().to(dtype=torch.float)
+        self.dequant = quantization.DeQuantStub()
+
+    def forward(self, x):
+        """Return ``dequant(elu(hardswish(quant(x))))``."""
+        activated = self.hardswish(self.quant(x))
+        return self.dequant(self.elu(activated))
+
+
+class LinearReluModel(torch.nn.Module):
+    """``Linear(5, 5)`` followed by ReLU."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        """Return ``relu(fc(x))``."""
+        projected = self.fc(x)
+        return self.relu(projected)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class LinearReluLinearModel(torch.nn.Module):
+    """Linear (5 -> 8), ReLU, Linear (8 -> 5)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc2(relu(fc1(x)))``."""
+        activated = self.relu(self.fc1(x))
+        return self.fc2(activated)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class LinearReluAddModel(torch.nn.Module):
+    """Linear, ReLU, add the constant 5, Linear (all 5 -> 5)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc2(relu(fc1(x)) + 5)``.
+
+        Side effect: ``self.relu`` is rebound to a new ``ReLU`` on every call.
+        """
+        activated = self.relu(self.fc1(x))
+        shifted = torch.add(activated, 5)
+        result = self.fc2(shifted)
+        # NOTE(review): replacing the submodule inside forward looks odd but is
+        # kept as-is; confirm whether any caller relies on it.
+        self.relu = torch.nn.ReLU()
+        return result
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class LinearBnLeakyReluModel(torch.nn.Module):
+    """Linear, optional BatchNorm1d, then LeakyReLU(0.01)."""
+
+    def __init__(self, with_bn=True):
+        super().__init__()
+        self.linear = nn.Linear(5, 5)
+        self.bn1d = nn.BatchNorm1d(5)
+        self.leaky_relu = nn.LeakyReLU(0.01)
+        self.with_bn = with_bn
+
+    def forward(self, x):
+        """Return ``leaky_relu`` of the linear output, batch-normed when ``with_bn``."""
+        hidden = self.linear(x)
+        if self.with_bn:
+            hidden = self.bn1d(hidden)
+        return self.leaky_relu(hidden)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class LinearTanhModel(torch.nn.Module):
+    """``Linear(5, 5)`` followed by Tanh."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = nn.Linear(5, 5)
+        self.tanh = nn.Tanh()
+
+    def forward(self, x):
+        """Return ``tanh(linear(x))``."""
+        projected = self.linear(x)
+        return self.tanh(projected)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class ConvBnAddReluModel(torch.nn.Module):
+    """Configurable conv (+bn) + add (+relu) pattern.
+
+    Flags:
+        with_bn: apply ``bn`` to the output of ``conv``.
+        with_relu: apply ``relu`` to the sum.
+        left_conv: when ``two_conv`` is false, put the conv branch on the left
+            of the addition (otherwise ``x2`` is on the left).
+        two_conv: add ``conv2(x1)`` instead of ``x2``; the conv branch is then
+            always the left operand.
+        use_torch_add: add with ``torch.add`` instead of the ``+`` operator.
+    """
+
+    def __init__(
+        self,
+        with_bn=True,
+        with_relu=True,
+        left_conv=True,
+        two_conv=True,
+        use_torch_add=True,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(5, 5, (2, 2))
+        self.conv2 = nn.Conv2d(5, 5, (2, 2))
+        self.bn = nn.BatchNorm2d(5)
+        self.relu = nn.ReLU()
+        self.with_bn = with_bn
+        self.with_relu = with_relu
+        self.two_conv = two_conv
+        self.left_conv = left_conv
+        self.use_torch_add = use_torch_add
+
+    def forward(self, x1, x2):
+        """Add the conv branch of ``x1`` to ``conv2(x1)`` or ``x2``, per the flags."""
+        # The conv (+bn) branch is always evaluated before the other operand.
+        conv_branch = self.conv(x1)
+        if self.with_bn:
+            conv_branch = self.bn(conv_branch)
+
+        if self.two_conv:
+            lhs, rhs = conv_branch, self.conv2(x1)
+        elif self.left_conv:
+            lhs, rhs = conv_branch, x2
+        else:
+            lhs, rhs = x2, conv_branch
+
+        if self.use_torch_add:
+            x = torch.add(lhs, rhs)
+        else:
+            x = lhs + rhs
+
+        if self.with_relu:
+            x = self.relu(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return random ``(1, 5, 3, 3)`` and ``(1, 5, 2, 2)`` tensors for ``x1``/``x2``."""
+        first = torch.rand(1, 5, 3, 3)
+        second = torch.rand(1, 5, 2, 2)
+        return (first, second)
+
+
+# TODO: self.fc should be self.conv
+class ConvReluModel(torch.nn.Module):
+    """``Conv2d(3, 5, 3)`` (stored as ``fc``) followed by ReLU."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        """Return ``relu(fc(x))``."""
+        convolved = self.fc(x)
+        return self.relu(convolved)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+# TODO: self.fc should be self.conv
+class ConvReluConvModel(torch.nn.Module):
+    """``Conv2d(3, 5, 3)``, ReLU, ``Conv2d(5, 5, 1)`` (convs stored as ``fc1``/``fc2``)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc2(relu(fc1(x)))``."""
+        activated = self.relu(self.fc1(x))
+        return self.fc2(activated)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+# TODO: self.fc should be self.conv
+class ConvReluAddModel(torch.nn.Module):
+    """Conv2d, ReLU, add the constant 5, Conv2d (convs stored as ``fc1``/``fc2``)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc2(relu(fc1(x)) + 5)``.
+
+        Side effect: ``self.relu`` is rebound to a new ``ReLU`` on every call.
+        """
+        activated = self.relu(self.fc1(x))
+        shifted = torch.add(activated, 5)
+        result = self.fc2(shifted)
+        # NOTE(review): replacing the submodule inside forward looks odd but is
+        # kept as-is; confirm whether any caller relies on it.
+        self.relu = torch.nn.ReLU()
+        return result
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class NormalizationTestModel(torch.nn.Module):
+    """Linear layer followed by layer, group and 1d/2d/3d instance normalization."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.layer_norm = torch.nn.LayerNorm(8)
+        self.group_norm = torch.nn.GroupNorm(2, 8)
+        self.instance_norm1d = torch.nn.InstanceNorm1d(8)
+        self.instance_norm2d = torch.nn.InstanceNorm2d(8)
+        self.instance_norm3d = torch.nn.InstanceNorm3d(8)
+
+    def forward(self, x):
+        """Run every normalization in turn, adding trailing axes as each one needs."""
+        normed = self.layer_norm(self.fc1(self.quant(x)))
+        # Add a trailing axis and repeat it three times before group norm.
+        widened = normed.unsqueeze(-1).repeat(1, 1, 3)
+        out = self.instance_norm1d(self.group_norm(widened))
+        # Each higher-rank instance norm gets one more trailing axis.
+        out = self.instance_norm2d(out.unsqueeze(-1))
+        return self.instance_norm3d(out.unsqueeze(-1))
+
+
+class NestedModel(torch.nn.Module):
+    """``LinearReluModel`` -> ``TwoLayerLinearModel`` -> ``Linear(5, 5)``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = TwoLayerLinearModel()
+        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc3(sub2(sub1(x)))``."""
+        inner = self.sub2(self.sub1(x))
+        return self.fc3(inner)
+
+
+class AnnotatedNestedModel(torch.nn.Module):
+    """Nested model where ``fc3`` and ``sub2.fc1`` are wrapped and given qconfigs.
+
+    ``sub2.fc1`` gets ``default_per_channel_qconfig`` when ``qengine`` is
+    ``"fbgemm"`` and ``default_qconfig`` otherwise.
+    """
+
+    def __init__(self, qengine):
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = TwoLayerLinearModel()
+        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+        self.fc3.qconfig = default_qconfig
+        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
+        self.sub2.fc1.qconfig = (
+            default_per_channel_qconfig if qengine == "fbgemm" else default_qconfig
+        )
+
+    def forward(self, x):
+        """Return ``fc3(sub2(sub1(x)))``."""
+        inner = self.sub2(self.sub1(x))
+        return self.fc3(inner)
+
+
+class AnnotatedSubNestedModel(torch.nn.Module):
+    """Nested model where all of ``sub2`` and ``fc3`` are wrapped with ``default_qconfig``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = QuantWrapper(TwoLayerLinearModel())
+        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+        # Only the two wrapped children carry a qconfig.
+        self.fc3.qconfig = default_qconfig
+        self.sub2.qconfig = default_qconfig
+
+    def forward(self, x):
+        """Return ``fc3(sub2(sub1(x)))``."""
+        inner = self.sub2(self.sub1(x))
+        return self.fc3(inner)
+
+
+class AnnotatedCustomConfigNestedModel(torch.nn.Module):
+    """Nested model whose ``sub2.fc1`` carries a custom qconfig.
+
+    ``fc3`` and ``sub2`` get ``default_qconfig``; ``sub2.fc1`` gets a QConfig
+    built from ``default_observer`` with quint8 / per-tensor-affine options and
+    ``default_weight_observer``. Both ``sub2`` linears are then wrapped in
+    ``QuantWrapper``.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = TwoLayerLinearModel()
+        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+        self.fc3.qconfig = default_qconfig
+        self.sub2.qconfig = default_qconfig
+
+        # Activation observer options for the custom qconfig on sub2.fc1.
+        custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine}
+        custom_qconfig = QConfig(
+            activation=default_observer.with_args(**custom_options),
+            weight=default_weight_observer,
+        )
+        # The custom qconfig is attached to the bare Linear before it is wrapped.
+        self.sub2.fc1.qconfig = custom_qconfig
+
+        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
+        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)
+
+    def forward(self, x):
+        """Return ``fc3(sub2(sub1(x)))``."""
+        x = self.sub1(x)
+        x = self.sub2(x)
+        x = self.fc3(x)
+        return x
+
+
+class QuantSubModel(torch.nn.Module):
+    """Nested model with a wrapped ``sub2`` and an unwrapped ``fc3``, both with ``default_qconfig``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = QuantWrapper(TwoLayerLinearModel())
+        self.sub2.qconfig = default_qconfig
+        # fc3 carries a qconfig but is not wrapped in a QuantWrapper.
+        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.fc3.qconfig = default_qconfig
+
+    def forward(self, x):
+        """Return ``fc3(sub2(sub1(x)))``."""
+        inner = self.sub2(self.sub1(x))
+        return self.fc3(inner)
+
+
+class InnerModule(torch.nn.Module):
+    """Linear (5 -> 8), ReLU, Linear (8 -> 5), ReLU, with a Linear+ReLU fusion helper."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.relu1 = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+        self.relu2 = torch.nn.ReLU()
+
+    def forward(self, x):
+        """Return ``relu2(fc2(relu1(fc1(x))))``."""
+        hidden = self.relu1(self.fc1(x))
+        return self.relu2(self.fc2(hidden))
+
+    def fuse_modules(self):
+        """Fuse, in place, every ``Linear`` child directly followed by a ``ReLU`` child."""
+        children = list(self.named_children())
+        # Pair each child with its successor; a trailing Linear has no partner.
+        fusable_layers = [
+            [name, next_name]
+            for (name, layer), (next_name, next_layer) in zip(children, children[1:])
+            if isinstance(layer, torch.nn.Linear)
+            and isinstance(next_layer, torch.nn.ReLU)
+        ]
+        # TODO: remove this check and define two fuse_modules function on this module
+        fuse = (
+            torch.ao.quantization.fuse_modules_qat
+            if self.training
+            else torch.ao.quantization.fuse_modules
+        )
+        fuse(self, fusable_layers, inplace=True)
+
+
+class FunctionalLinear(torch.nn.Module):
+    """Linear layer via ``F.linear`` whose weight and bias are plain tensor attributes."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Random (5, 5) weight and zero bias, assigned as plain tensors.
+        self.weight = torch.rand((5, 5))
+        self.bias = torch.zeros(5)
+
+    def forward(self, x):
+        """Return ``F.linear(x, weight, bias)``."""
+        result = F.linear(x, self.weight, self.bias)
+        return result
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 5)`` tensor."""
+        sample = torch.rand(1, 5)
+        return (sample,)
+
+
+class SingleLayerFunctionalLinearModel(torch.nn.Module):
+    """Model holding a single ``FunctionalLinear`` layer."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+
+    def forward(self, x):
+        """Return ``linear1(x)``."""
+        return self.linear1(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``linear1``'s example inputs."""
+        return self.linear1.get_example_inputs()
+
+
+class TwoLayerFunctionalLinearModel(torch.nn.Module):
+    """Two ``FunctionalLinear`` layers applied back to back."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+        self.linear2 = FunctionalLinear()
+
+    def forward(self, x):
+        """Return ``linear2(linear1(x))``."""
+        hidden = self.linear1(x)
+        return self.linear2(hidden)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``linear1``'s example inputs."""
+        return self.linear1.get_example_inputs()
+
+
+class FunctionalLinearAddModel(torch.nn.Module):
+    """Two ``FunctionalLinear`` layers with the constant 5 added in between."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+        self.linear2 = FunctionalLinear()
+
+    def forward(self, x):
+        """Return ``linear2(torch.add(linear1(x), 5))``."""
+        shifted = torch.add(self.linear1(x), 5)
+        return self.linear2(shifted)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``linear1``'s example inputs."""
+        return self.linear1.get_example_inputs()
+
+
+class FunctionalLinearReluModel(nn.Module):
+    """``FunctionalLinear`` followed by the functional ``F.relu``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = FunctionalLinear()
+
+    def forward(self, x):
+        """Return ``F.relu(linear(x))``."""
+        projected = self.linear(x)
+        return F.relu(projected)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``linear``'s example inputs."""
+        return self.linear.get_example_inputs()
+
+
+class FunctionalLinearReluLinearModel(nn.Module):
+    """``FunctionalLinear``, ``nn.ReLU`` module, ``FunctionalLinear``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+        self.relu = nn.ReLU()
+        self.linear2 = FunctionalLinear()
+
+    def forward(self, x):
+        """Return ``linear2(relu(linear1(x)))``."""
+        activated = self.relu(self.linear1(x))
+        return self.linear2(activated)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``linear1``'s example inputs."""
+        return self.linear1.get_example_inputs()
+
+
+class FunctionalConv2d(torch.nn.Module):
+    """2d convolution via ``F.conv2d`` whose weight and bias are plain tensor attributes."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Random (3, 3, 3, 3) weight and (3,) bias, assigned as plain tensors.
+        self.weight = torch.rand(3, 3, 3, 3)
+        self.bias = torch.rand(3)
+        # Convolution hyper-parameters passed positionally to F.conv2d.
+        self.stride = (1, 1)
+        self.padding = (0, 0)
+        self.dilation = (1, 1)
+        self.groups = 1
+
+    def forward(self, x):
+        """Return ``F.conv2d`` of ``x`` with the stored weight, bias and settings."""
+        result = F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+        return result
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+        sample = torch.rand(1, 3, 5, 5)
+        return (sample,)
+
+
+class SingleLayerFunctionalConvModel(torch.nn.Module):
+    """Model holding a single ``FunctionalConv2d`` layer."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = FunctionalConv2d()
+
+    def forward(self, x):
+        """Return ``conv1(x)``."""
+        return self.conv1(x)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``conv1``'s example inputs."""
+        return self.conv1.get_example_inputs()
+
+
+class TwoLayerFunctionalConvModel(torch.nn.Module):
+    """Two ``FunctionalConv2d`` layers applied back to back."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = FunctionalConv2d()
+        self.conv2 = FunctionalConv2d()
+
+    def forward(self, x):
+        """Return ``conv2(conv1(x))``."""
+        first = self.conv1(x)
+        return self.conv2(first)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``conv1``'s example inputs."""
+        return self.conv1.get_example_inputs()
+
+
+class FunctionalConvReluModel(nn.Module):
+    """``FunctionalConv2d`` followed by the functional ``F.relu``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = FunctionalConv2d()
+
+    def forward(self, x):
+        """Return ``F.relu(conv(x))``."""
+        convolved = self.conv(x)
+        return F.relu(convolved)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``conv``'s example inputs."""
+        return self.conv.get_example_inputs()
+
+
+class FunctionalConvReluConvModel(nn.Module):
+    """``FunctionalConv2d``, ``nn.ReLU`` module, ``FunctionalConv2d``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = FunctionalConv2d()
+        self.relu = nn.ReLU()
+        self.conv2 = FunctionalConv2d()
+
+    def forward(self, x):
+        """Return ``conv2(relu(conv1(x)))``."""
+        activated = self.relu(self.conv1(x))
+        return self.conv2(activated)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        """Delegate to ``conv1``'s example inputs."""
+        return self.conv1.get_example_inputs()
+
+
+class SkipQuantModel(torch.nn.Module):
+    r"""An ``InnerModule`` followed by ``Linear(5, 5)``.
+
+    Quantization of a submodule can be skipped by explicitly setting its
+    qconfig to None.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub = InnerModule()
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc(sub(x))``."""
+        inner = self.sub(x)
+        return self.fc(inner)
+
+    def fuse_modules(self):
+        """Delegate fusion to the inner submodule."""
+        self.sub.fuse_modules()
+
+
+class AnnotatedSkipQuantModel(torch.nn.Module):
+    r"""A wrapped ``InnerModule`` followed by a ``Linear`` whose qconfig is None.
+
+    Quantization of a submodule can be skipped by explicitly setting its
+    qconfig to None, as done here for ``fc``.
+    """
+
+    def __init__(self, qengine):
+        super().__init__()
+        lookup_qconfig = torch.ao.quantization.get_default_qconfig
+        self.qconfig = lookup_qconfig(qengine)
+        self.sub = QuantWrapper(InnerModule())
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        # don't quantize this fc
+        self.fc.qconfig = None
+
+    def forward(self, x):
+        """Return ``fc(sub(x))``."""
+        inner = self.sub(x)
+        return self.fc(inner)
+
+    def fuse_modules(self):
+        """Delegate fusion to the module held inside the wrapper."""
+        self.sub.module.fuse_modules()
+
+
+class QuantStubModel(torch.nn.Module):
+    r"""``Linear(5, 5)`` with manually inserted `QuantStub` and `DeQuantStub`."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        lookup_qconfig = torch.ao.quantization.get_default_qconfig
+        self.qconfig = lookup_qconfig("qnnpack")
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``dequant(fc(quant(x)))``."""
+        projected = self.fc(self.quant(x))
+        return self.dequant(projected)
+
+
+class ManualLinearQATModel(torch.nn.Module):
+    r"""Two linears (5 -> 1 -> 10) with manual `QuantStub`/`DeQuantStub` and a QAT qconfig."""
+
+    def __init__(self, qengine):
+        super().__init__()
+        lookup_qat_qconfig = torch.ao.quantization.get_default_qat_qconfig
+        self.qconfig = lookup_qat_qconfig(qengine)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``dequant(fc2(fc1(quant(x))))``."""
+        hidden = self.fc1(self.quant(x))
+        return self.dequant(self.fc2(hidden))
+
+
+class ManualDropoutQATModel(torch.nn.Module):
+    r"""``Linear(5, 1)`` + ``Dropout(0.5)`` with manual `QuantStub`/`DeQuantStub` and a QAT qconfig."""
+
+    def __init__(self, qengine):
+        super().__init__()
+        lookup_qat_qconfig = torch.ao.quantization.get_default_qat_qconfig
+        self.qconfig = lookup_qat_qconfig(qengine)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
+        self.dropout = torch.nn.Dropout(0.5)
+
+    def forward(self, x):
+        """Return ``dequant(dropout(fc1(quant(x))))``."""
+        hidden = self.fc1(self.quant(x))
+        return self.dequant(self.dropout(hidden))
+
+
+class ManualLinearDynamicQATModel(torch.nn.Module):
+    r"""Two linears (5 -> 1 -> 10); uses ``default_dynamic_qat_qconfig`` unless a truthy qconfig is given."""
+
+    def __init__(self, qconfig=None):
+        super().__init__()
+        chosen = qconfig or default_dynamic_qat_qconfig
+        self.qconfig = chosen
+        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``fc2(fc1(x))``."""
+        hidden = self.fc1(x)
+        return self.fc2(hidden)
+
+
+class ManualConvLinearQATModel(torch.nn.Module):
+    r"""Conv2d + two linears with manually inserted `QuantStub` and `DeQuantStub`.
+
+    Uses the given ``qconfig`` when it is truthy, otherwise the default QAT
+    qconfig for ``"qnnpack"``.
+    """
+
+    def __init__(self, qconfig=None):
+        super().__init__()
+        self.qconfig = qconfig or torch.ao.quantization.get_default_qat_qconfig(
+            "qnnpack"
+        )
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
+        self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Quantize, convolve, flatten to rows of 64 features, apply both linears, dequantize."""
+        convolved = self.conv(self.quant(x))
+        flat = convolved.view(-1, 64).contiguous()
+        hidden = self.fc1(flat)
+        return self.dequant(self.fc2(hidden))
+
+
+class ManualConvLinearSymmQATModel(ManualConvLinearQATModel):
+    r"""``ManualConvLinearQATModel`` configured with the symmetric qnnpack QAT qconfig.
+
+    Supported only with qnnpack.
+    """
+
+    def __init__(self) -> None:
+        symmetric_qconfig = default_symmetric_qnnpack_qat_qconfig
+        super().__init__(symmetric_qconfig)
+
+
+class ManualEmbeddingBagLinear(nn.Module):
+    """Sum-mode ``EmbeddingBag(10, 12)`` feeding a quantized ``Linear(12, 1)``.
+
+    The embedding bag gets ``default_embedding_qat_qconfig``; the module itself
+    gets the default QAT qconfig for ``"qnnpack"``.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum")
+        self.emb.qconfig = default_embedding_qat_qconfig
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
+        self.qconfig = get_default_qat_qconfig("qnnpack")
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+        per_sample_weights: Optional[torch.Tensor] = None,
+    ):
+        """Embed ``input`` and return ``dequant(linear(quant(embedded)))``."""
+        bagged = self.emb(input, offsets, per_sample_weights)
+        projected = self.linear(self.quant(bagged))
+        return self.dequant(projected)
+
+
+class DeFusedEmbeddingBagLinear(nn.Module):
+    r"""Simulate a QAT embedding bag with a linear layer using separate ops.
+
+    An ``Embedding`` lookup is followed by an explicit bagging op
+    (``torch.sum`` over dim 1), similar to what the EmbeddingBag
+    documentation describes.
+
+    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
+        self.emb.qconfig = default_embedding_qat_qconfig
+        self.bagging_op = torch.sum
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
+        self.qconfig = get_default_qat_qconfig("qnnpack")
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Embed ``input``, bag over dim 1, and return ``dequant(linear(quant(bagged)))``."""
+        embedded = self.emb(input)
+        bagged = self.bagging_op(embedded, dim=1)
+        projected = self.linear(self.quant(bagged))
+        return self.dequant(projected)
+
+
+class SubModelForFusion(nn.Module):
+    """Bias-free 1x1 ``Conv2d(2, 2)`` followed by ``BatchNorm2d(2)``."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
+        self.bn = nn.BatchNorm2d(2).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``bn(conv(x))``."""
+        convolved = self.conv(x)
+        return self.bn(convolved)
+
+
+class SubModelWithoutFusion(nn.Module):
+    """Bias-free 1x1 ``Conv2d(2, 2)`` followed by an out-of-place ReLU."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=False).to(dtype=torch.float)
+
+    def forward(self, x):
+        """Return ``relu(conv(x))``."""
+        convolved = self.conv(x)
+        return self.relu(convolved)
+
+
+class ModelForFusion(nn.Module):
+    """Fusion test model mixing 1d, 2d and 3d conv/bn/relu sequences.
+
+    ``qconfig`` is stored on the module; ``sub2`` and ``fc`` have their
+    qconfig set to None.
+    """
+
+    def __init__(self, qconfig):
+        super().__init__()
+        # 2d sequence: conv1 -> bn1 -> relu1
+        self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float)
+        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
+        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        self.sub1 = SubModelForFusion()
+        self.sub2 = SubModelWithoutFusion()
+        self.fc = nn.Linear(36, 10).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.qconfig = qconfig
+        # 3d sequence: conv2 -> relu2 -> bn2 -> relu3
+        self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float)
+        self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
+        self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
+        self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        # 1d sequence: conv3 -> bn3 -> relu4
+        self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
+        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
+        self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        # don't quantize sub2
+        self.sub2.qconfig = None
+        self.fc.qconfig = None
+
+    def forward(self, x):
+        """Run the 1d, 2d and 3d branches and return the output of ``fc``.
+
+        The 3d branch result ``y`` is computed but not returned.
+        """
+        x = x.squeeze(2)
+        x = self.quant(x)
+        # 1d branch
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu4(x)
+        x = x.unsqueeze(2)
+        # y gets one more axis than x and feeds the 3d branch below.
+        y = x.unsqueeze(2)
+        # 2d branch
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.sub1(x)
+        x = self.dequant(x)
+        x = self.sub2(x)
+        x = x.reshape(-1, 36).contiguous()
+        x = self.fc(x)
+        # 3d branch; its result is not part of the return value.
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.bn2(y)
+        y = self.relu3(y)
+        y = self.dequant(y)
+        return x
+
+
+class ConvBNReLU(nn.Sequential):
+    """Sequential of a bias-free 1x1 ``Conv2d(3, 3)``, ``BatchNorm2d(3)`` and ReLU."""
+
+    def __init__(self) -> None:
+        conv = nn.Conv2d(3, 3, 1, 1, bias=False)
+        norm = nn.BatchNorm2d(3)
+        activation = nn.ReLU(inplace=False)
+        super().__init__(conv, norm, activation)
+
+
+class ModelWithSequentialFusion(nn.Module):
+    """Conv + ReLU, three ``ConvBNReLU`` blocks, a Linear + ReLU head and an empty Sequential."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 3, 1)
+        self.relu1 = nn.ReLU(inplace=False)
+        feature_blocks = [ConvBNReLU() for _ in range(3)]
+        self.features = nn.Sequential(*feature_blocks)
+        self.classifier = nn.Sequential(nn.Linear(300, 10), nn.ReLU(inplace=False))
+        self.seq = nn.Sequential()
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        """Run the conv stack, flatten to rows of 300 features, classify and dequantize."""
+        stem = self.relu1(self.conv1(self.quant(x)))
+        features = self.features(stem)
+        flat = torch.reshape(features, (-1, 3 * 10 * 10))
+        logits = self.seq(self.classifier(flat))
+        return self.dequant(logits)
+
+
+class ModelForFusionWithBias(nn.Module):
+    """Conv/BN fusion fixture in which both convolutions carry a bias term.
+
+    Layout: ``conv1 -> bn1 -> relu1 -> conv2 -> bn2`` between a ``QuantStub``
+    and a ``DeQuantStub``; all layers are explicitly cast to ``torch.float``.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float)
+        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
+        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float)
+        self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        """Apply quant -> conv1/bn1/relu1 -> conv2/bn2 -> dequant and return the result."""
+        x = self.quant(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.dequant(x)
+        return x
+
+
+class ModelForLinearBNFusion(nn.Module):
+    """``Linear(20, 10)`` followed by ``BatchNorm1d(10)``, used as a Linear+BN fusion fixture."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(20, 10)
+        self.bn = nn.BatchNorm1d(10)
+        # Re-initialise the BN affine parameters with uniform random values
+        # instead of leaving them at their default initialisation.
+        nn.init.uniform_(self.bn.weight)
+        nn.init.uniform_(self.bn.bias)
+
+    def forward(self, x):
+        """Return ``bn(fc(x))``."""
+        return self.bn(self.fc(x))
+
+
+class DummyObserver(torch.nn.Module):
+    """Pass-through observer stand-in that reports fixed quantization parameters."""
+
+    def calculate_qparams(self):
+        """Return the constant pair ``(1.0, 0)`` (presumably scale and zero point)."""
+        return 1.0, 0
+
+    def forward(self, x):
+        """Return ``x`` unchanged."""
+        return x
+
+
+class ModelForConvTransposeBNFusion(nn.Module):
+    """Chain of ConvTranspose{1,2,3}d layers, each followed by the matching BatchNorm.
+
+    The input is treated as 1d data; a dimension is inserted at index 2 before
+    the 2d stage and again before the 3d stage so that each stage receives a
+    tensor of the rank it expects.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.ConvTranspose1d(3, 3, 1)
+        self.bn1 = nn.BatchNorm1d(3)
+        self.conv2 = nn.ConvTranspose2d(3, 3, 1)
+        self.bn2 = nn.BatchNorm2d(3)
+        self.conv3 = nn.ConvTranspose3d(3, 3, 1)
+        self.bn3 = nn.BatchNorm3d(3)
+
+    def forward(self, x):
+        """Run the 1d, 2d and 3d ConvTranspose+BN stages in sequence."""
+        x = self.conv1(x)
+        x = self.bn1(x)
+        # Add a size-1 dimension so the tensor can feed the 2d stage.
+        x = x.unsqueeze(2)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        # Add another size-1 dimension for the 3d stage.
+        x = x.unsqueeze(2)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        return x
+
+
+class ModelWithFunctionals(torch.nn.Module):
+    """Model built only from ``nnq.FloatFunctional`` ops (cat, add, add_relu, matmul)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # One FloatFunctional instance per operation used in forward().
+        self.mycat = nnq.FloatFunctional()
+        self.myadd = nnq.FloatFunctional()
+        self.myadd_relu = nnq.FloatFunctional()
+        self.mymatmul = nnq.FloatFunctional()
+        # Tracing doesn't work yet for c10 ops with scalar inputs
+        # https://github.com/pytorch/pytorch/issues/27097
+        # self.my_scalar_add = nnq.FloatFunctional()
+        # self.my_scalar_mul = nnq.FloatFunctional()
+
+    def forward(self, x):
+        """Return ``matmul(w, w.T)`` where ``w = add_relu(z, z)``, ``z = add(y, y)``, ``y = cat([x, x, x])``."""
+        y = self.mycat.cat([x, x, x])
+        z = self.myadd.add(y, y)
+        w = self.myadd_relu.add_relu(z, z)
+        u = self.mymatmul.matmul(w, w.T)
+        # Tracing doesn't work yet for c10 ops with scalar inputs
+        # https://github.com/pytorch/pytorch/issues/27097
+        # w = self.my_scalar_add.add_scalar(w, -0.5)
+        # w = self.my_scalar_mul.mul_scalar(w, 0.5)
+        return u
+
+
+class ResNetBase(torch.nn.Module):
+    """Minimal residual block: conv/bn/relu plus an identity skip, then pooling and a linear head.
+
+    The skip connection is added through a ``FloatFunctional`` (``myop``) and
+    ``downsample`` is an ``Identity`` module.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.downsample = torch.nn.Identity()
+        self.myop = nn.quantized.FloatFunctional()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = torch.nn.Linear(inplanes, 1)
+
+    def forward(self, x):
+        """Compute ``fc(flatten(avgpool(relu2(relu1(bn1(conv1(x))) + downsample(x)))))``."""
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        identity = self.downsample(x)
+        # Residual addition goes through the FloatFunctional wrapper.
+        out = self.myop.add(out, identity)
+        out = self.relu2(out)
+        out = self.avgpool(out)
+        out = torch.flatten(out, 1)
+        out = self.fc(out)
+        return out
+
+    def fuse_model(self):
+        """Fuse ``conv1``/``bn1``/``relu1`` in place, using the QAT variant when in training mode."""
+        # TODO: remove this check and define two fuse_model function on this module
+        if self.training:
+            torch.ao.quantization.fuse_modules_qat(
+                self, [["conv1", "bn1", "relu1"]], inplace=True
+            )
+        else:
+            torch.ao.quantization.fuse_modules(
+                self, [["conv1", "bn1", "relu1"]], inplace=True
+            )
+
+
+class ModelMultipleOps(torch.nn.Module):
+    """Residual-style model mixing conv, BN, ReLU, add, avg-pool, max-pool, cat and linear ops."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.downsample = torch.nn.Identity()
+        self.skip_add = nn.quantized.FloatFunctional()
+        self.cat = nn.quantized.FloatFunctional()
+        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
+        self.fc = nn.Linear(12, 6)
+
+    def forward(self, x):
+        """Run the conv/bn/relu + skip branch, pool, concatenate the result with itself and classify."""
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        identity = self.downsample(x)
+        out = self.skip_add.add(out, identity)
+        out = self.relu2(out)
+        # Adaptive pooling to 4x4, then a 2x2/stride-2 max pool leaves 2x2 maps.
+        out = self.avgpool(out)
+        out = self.conv2(out)
+        out = torch.nn.functional.max_pool2d(out, 2, 2)
+        out = self.cat.cat([out, out])
+        # 3 channels * 2 * 2 = 12 features per row, matching fc's input size.
+        out = out.reshape(-1, 3 * 2 * 2)
+        out = self.fc(out)
+        return out
+
+
+# Model to ensure consistency of fake quant with true quant
+# Average pooling and mean operations are not modelled
+# accurately with fake-quant so this model does not
+# contain those operations
+class ModelMultipleOpsNoAvgPool(torch.nn.Module):
+    """Variant of the multi-op model that uses ``MaxPool2d`` instead of average pooling.
+
+    The skip branch is ``conv2(x)`` rather than an identity, and ``conv2`` is
+    applied a second time after the max-pool.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.skip_add = nn.quantized.FloatFunctional()
+        self.cat = nn.quantized.FloatFunctional()
+        self.maxpool = nn.MaxPool2d((4, 4))
+        self.fc = nn.Linear(12, 6)
+
+    def forward(self, x):
+        """Run conv/bn/relu + conv2 skip, max-pool, conv2 again, max-pool, cat and the linear head."""
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        # The skip path reuses conv2, which is applied again further below.
+        skip = self.conv2(x)
+        out = self.skip_add.add(out, skip)
+        out = self.relu2(out)
+        out = self.maxpool(out)
+        out = self.conv2(out)
+        out = torch.nn.functional.max_pool2d(out, 2, 2)
+        out = self.cat.cat([out, out])
+        # Rows of 12 features to match fc = Linear(12, 6).
+        out = out.reshape(-1, 3 * 2 * 2)
+        out = self.fc(out)
+        return out
+
+
+class EmbeddingBagModule(torch.nn.Module):
+    """Wraps a sum-mode ``EmbeddingBag`` (10 embeddings of dim 12, ``include_last_offset=True``)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = torch.nn.EmbeddingBag(
+            num_embeddings=10,
+            embedding_dim=12,
+            include_last_offset=True,
+            scale_grad_by_freq=False,
+            mode="sum",
+        )
+
+    def forward(self, indices, offsets, per_sample_weights):
+        """Forward all three arguments positionally to the wrapped ``EmbeddingBag``."""
+        return self.emb(indices, offsets, per_sample_weights)
+
+
+class EmbeddingModule(torch.nn.Module):
+    """Wraps a single ``Embedding`` table with 10 entries of dimension 12."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+
+    def forward(self, indices):
+        """Return the embeddings looked up for ``indices``."""
+        return self.emb(indices)
+
+
+class EmbeddingWithStaticLinear(torch.nn.Module):
+    """``EmbeddingBag`` plus a ``Linear`` layer with different qconfigs per submodule.
+
+    The embedding bag gets ``float_qparams_weight_only_qconfig`` while the
+    module itself uses ``default_qconfig``; only the linear path goes through
+    the quant/dequant stubs.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12)
+        self.fc = torch.nn.Linear(4, 2)
+        self.emb.qconfig = float_qparams_weight_only_qconfig
+        self.qconfig = default_qconfig
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, indices, offsets, linear_in):
+        """Concatenate the dequantized linear output with the embedding-bag output along dim 1."""
+        emb = self.emb(indices, offsets)
+        q_x = self.quant(linear_in)
+        fc = self.fc(q_x)
+        fc = self.dequant(fc)
+        features = torch.cat([fc] + [emb], dim=1)
+        return features
+
+
+class DenseTopMLP(nn.Module):
+    """Dense MLP whose output is concatenated with a sparse feature and fed to a top MLP.
+
+    Args:
+        dense_dim: input width of the dense ``Linear`` layer.
+        dense_out: output width of the dense ``Linear`` layer.
+        embedding_dim: width of the sparse feature concatenated in ``forward``.
+        top_out_in: hidden width of the top MLP.
+        top_out_out: output width of the top MLP.
+    """
+
+    def __init__(
+        self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out
+    ) -> None:
+        super().__init__()
+
+        self.dense_mlp = nn.Sequential(
+            nn.Linear(dense_dim, dense_out),
+        )
+        # Input width is dense_out + embedding_dim because forward() concatenates
+        # the dense output with the sparse feature along dim 1.
+        self.top_mlp = nn.Sequential(
+            nn.Linear(dense_out + embedding_dim, top_out_in),
+            nn.Linear(top_out_in, top_out_out),
+        )
+
+    def forward(
+        self,
+        sparse_feature: torch.Tensor,
+        dense: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return ``top_mlp(cat([dense_mlp(dense), sparse_feature], dim=1))``."""
+        dense_feature = self.dense_mlp(dense)
+        features = torch.cat([dense_feature] + [sparse_feature], dim=1)
+
+        out = self.top_mlp(features)
+        return out
+
+
+# Thin wrapper around nn.EmbeddingBag, because tracing inside nn.EmbeddingBag
+# is not supported at the moment and this wrapper is top level.
+class EmbBagWrapper(nn.Module):
+    """Module holding a single sum-mode ``nn.EmbeddingBag`` as ``emb_bag``."""
+
+    def __init__(self, num_embeddings, embedding_dim):
+        super().__init__()
+        self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
+
+    def forward(self, indices, offsets):
+        """Return ``emb_bag(indices, offsets)``."""
+        return self.emb_bag(indices, offsets)
+
+
+class SparseNNModel(nn.Module):
+    """Small sparse+dense model: an ``EmbBagWrapper`` feeding a ``DenseTopMLP``."""
+
+    # Sizes used to build the submodules in __init__.
+    _NUM_EMBEDDINGS = 10
+    _EMBEDDING_DIM = 5
+    _DENSE_DIM = 4
+    _DENSE_OUTPUT = 2
+    _TOP_OUT_IN = 2
+    _TOP_OUT_OUT = 2
+    # Not referenced anywhere within this class.
+    _TOP_MLP_DIM = 1
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM)
+        self.dense_top = DenseTopMLP(
+            self._DENSE_DIM,
+            self._DENSE_OUTPUT,
+            self._EMBEDDING_DIM,
+            self._TOP_OUT_IN,
+            self._TOP_OUT_OUT,
+        )
+
+    def forward(
+        self,
+        sparse_indices: torch.Tensor,
+        sparse_offsets: torch.Tensor,
+        dense: torch.Tensor,
+    ) -> torch.Tensor:
+        """Look up the sparse feature, then combine it with ``dense`` in the dense/top MLP."""
+        sparse_feature = self.model_sparse(sparse_indices, sparse_offsets)
+        out = self.dense_top(sparse_feature, dense)
+
+        return out
+
+
+class TestHelperModules:
+    class ControlFlow(torch.nn.Module):
+        """Exercises nested ``control_flow.cond`` calls inside a ``control_flow.map``."""
+
+        def forward(
+            self,
+            xs: torch.Tensor,
+            pred1: torch.Tensor,
+            pred2: torch.Tensor,
+            y: torch.Tensor,
+        ) -> torch.Tensor:
+            """Map ``map_fn`` over ``xs`` with ``pred1``, ``pred2`` and ``mm(y, y)`` as extra arguments."""
+
+            # Innermost branches, selected by pred2.
+            def true_nested(y: torch.Tensor) -> torch.Tensor:
+                y = y + y
+                y = torch.mm(y, y)
+                return y
+
+            def false_nested(y: torch.Tensor) -> torch.Tensor:
+                return torch.mm(y, y)
+
+            # Outer branches, selected by pred1.
+            def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
+                z = control_flow.cond(pred2, true_nested, false_nested, [x])
+                return x + z
+
+            def false_fn(x: torch.Tensor, _) -> torch.Tensor:
+                return x.cos()
+
+            # Body passed to control_flow.map.
+            def map_fn(
+                x: torch.Tensor,
+                pred1: torch.Tensor,
+                pred2: torch.Tensor,
+                y: torch.Tensor,
+            ) -> torch.Tensor:
+                x = x.cos()
+                y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
+                x = x + y
+                return x.sin()
+
+            y = torch.mm(y, y)
+            return control_flow.map(map_fn, xs, pred1, pred2, y)
+
+        def example_inputs(self):
+            """Return ``(xs, pred1, pred2, y)``: 2x2 ones for the tensors and ``[False]`` for both predicates."""
+            return (
+                torch.ones(2, 2),
+                torch.tensor([False]),
+                torch.tensor([False]),
+                torch.ones(2, 2),
+            )
+
+    class Conv2dPropAnnotaton(torch.nn.Module):
+        """``Conv2d`` -> ``view`` -> ``hardtanh`` -> ``Linear`` pipeline."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            """Convolve, reshape to rows of 3, clamp to [-0.5, 0.5] and apply the linear layer."""
+            x = self.conv(x)
+            x = x.view(-1, 3)
+            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
+            x = self.linear(x)
+            return x
+
+    class Conv2dWithObsSharingOps(torch.nn.Module):
+        """``Conv2d`` followed by adaptive average pooling, ``Hardtanh`` and ``torch.mean``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.hardtanh = torch.nn.Hardtanh()
+            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            """Return the mean of ``hardtanh(adaptive_avg_pool2d(conv(x)))``."""
+            x = self.conv(x)
+            x = self.adaptive_avg_pool2d(x)
+            x = self.hardtanh(x)
+            x = torch.mean(x)
+            return x
+
+    class Conv2dWithTwoLinearPermute(torch.nn.Module):
+        """``Conv2d(3, 16)`` whose output is permuted channels-last and fed through two linear layers."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3)
+            self.linear1 = torch.nn.Linear(16, 8, bias=False)
+            self.linear2 = torch.nn.Linear(8, 8)
+
+        def forward(self, x):
+            """Return ``linear2(linear1(permute(conv(x), (0, 2, 3, 1))))``."""
+            conv_out = self.conv(x)
+            # Move the 16 channels to the last dim so linear1 (in_features=16) applies.
+            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
+            return self.linear2(self.linear1(permute_out))
+
+    class Conv2dWithTwoLinear(torch.nn.Module):
+        """``Conv2d(3, 16)`` whose output is reshaped to ``(2, 64)`` and fed through two linear layers."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3)
+            self.linear1 = torch.nn.Linear(64, 8, bias=False)
+            self.linear2 = torch.nn.Linear(8, 8)
+
+        def forward(self, x):
+            """Return ``linear2(linear1(reshape(conv(x), (2, 64))))``."""
+            conv_out = self.conv(x)
+            # The fixed target shape requires the conv output to hold exactly 128 elements.
+            reshape_out = torch.reshape(conv_out, (2, 64))
+            return self.linear2(self.linear1(reshape_out))
+
+    class ConvLinearWPermute(torch.nn.Module):
+        """``Conv2d(3, 8)`` followed by a channels-last permute and one ``Linear(8, 8)``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 8, 3)
+            self.linear1 = torch.nn.Linear(8, 8)
+
+        def forward(self, x):
+            """Return ``linear1(permute(conv(x), (0, 2, 3, 1)))``."""
+            conv_out = self.conv(x)
+            # Channels move to the last dim to match linear1's in_features=8.
+            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
+            return self.linear1(permute_out)
+
+    class TwoLinearModule(torch.nn.Module):
+        """Two stacked linear layers: ``Linear(8, 16, bias=False)`` then ``Linear(16, 8)``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear1 = torch.nn.Linear(8, 16, bias=False)
+            self.linear2 = torch.nn.Linear(16, 8)
+
+        def forward(self, x):
+            """Return ``linear2(linear1(x))``."""
+            return self.linear2(self.linear1(x))
+
+        def example_inputs(self):
+            """Return a one-element tuple holding a random ``(2, 8)`` tensor."""
+            return (torch.randn(2, 8),)
+
+    class ConvMaxPool2d(torch.nn.Module):
+        """1x1 ``Conv2d(2, 2)`` followed by ``MaxPool2d`` with kernel size 1 and stride 1."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(2, 2, 1)
+            self.pool = torch.nn.MaxPool2d(1, 1)
+
+        def forward(self, x):
+            """Return ``pool(conv(x))``."""
+            x = self.conv(x)
+            x = self.pool(x)
+            return x
+
+    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
+        """3x3 ``Conv2d(3, 3)`` followed by ``AdaptiveAvgPool2d((1, 1))``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            """Return ``adaptive_avg_pool2d(conv(x))``."""
+            x = self.conv(x)
+            x = self.adaptive_avg_pool2d(x)
+            return x
+
+    class ConvWithBNRelu(torch.nn.Module):
+        """Configurable Conv{1,2,3}d -> (BatchNorm | Identity) -> (ReLU | Identity) block.
+
+        Args:
+            relu: when truthy, end with ``ReLU``; otherwise with ``Identity``.
+            dim: spatial dimensionality (1, 2 or 3) selecting the conv/BN classes.
+            bn: when truthy, insert the matching ``BatchNorm``; otherwise ``Identity``.
+            bias: forwarded to the convolution's ``bias`` argument.
+            padding: forwarded to the convolution's ``padding`` argument.
+        """
+
+        def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
+            super().__init__()
+            # Lookup tables keyed by dimensionality; an unsupported dim raises KeyError.
+            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+            bns = {
+                1: torch.nn.BatchNorm1d,
+                2: torch.nn.BatchNorm2d,
+                3: torch.nn.BatchNorm3d,
+            }
+            self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
+
+            if bn:
+                self.bn = bns[dim](3)
+            else:
+                self.bn = torch.nn.Identity()
+            if relu:
+                self.relu = torch.nn.ReLU()
+            else:
+                self.relu = torch.nn.Identity()
+
+        def forward(self, x):
+            """Return ``relu(bn(conv(x)))``."""
+            x = self.conv(x)
+            x = self.bn(x)
+            return self.relu(x)
+
+    class ConvTWithBNRelu(torch.nn.Module):
+        """Configurable ConvTranspose{1,2}d -> (BatchNorm | Identity) -> (ReLU | Identity) block.
+
+        Args:
+            relu: when truthy, end with ``ReLU``; otherwise with ``Identity``.
+            dim: spatial dimensionality (1 or 2) selecting the conv/BN classes.
+            bn: when truthy, insert the matching ``BatchNorm``; otherwise ``Identity``.
+            bias: forwarded to the transposed convolution's ``bias`` argument.
+        """
+
+        def __init__(self, relu, dim=2, bn=True, bias=True):
+            super().__init__()
+            # Only 1d and 2d are available here; any other dim raises KeyError.
+            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+            self.convt = convts[dim](3, 3, 3, bias=bias)
+
+            if bn:
+                self.bn = bns[dim](3)
+            else:
+                self.bn = torch.nn.Identity()
+            if relu:
+                self.relu = torch.nn.ReLU()
+            else:
+                self.relu = torch.nn.Identity()
+
+        def forward(self, x):
+            """Return ``relu(bn(convt(x)))``."""
+            x = self.convt(x)
+            x = self.bn(x)
+            return self.relu(x)
+
+    class Conv2dThenConv1d(torch.nn.Module):
+        """Applies a ``Conv2d`` and then, after dropping dim 0, a ``Conv1d``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1d = torch.nn.Conv1d(3, 3, 3)
+            self.conv2d = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            """Return ``conv1d(conv2d(x).squeeze(0))``."""
+            x = self.conv2d(x)
+            # squeeze(0) only removes dim 0 when it has size 1 (as in example_inputs),
+            # leaving a 3-d tensor for the Conv1d.
+            x = x.squeeze(0)
+            x = self.conv1d(x)
+            return x
+
+        def example_inputs(self):
+            """Return a one-element tuple holding a random ``(1, 3, 5, 5)`` tensor."""
+            return (torch.randn(1, 3, 5, 5),)
+
+    class Conv2dWithCat(torch.nn.Module):
+        """Two independent ``Conv2d(3, 3, 3)`` branches concatenated along the channel dim."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x, y):
+            """Return ``cat([conv1(x), conv2(y)], dim=1)``."""
+            x = self.conv1(x)
+            y = self.conv2(y)
+            z = torch.cat([x, y], dim=1)
+            return z
+
+    class Conv2dWithTwoCat(torch.nn.Module):
+        """Two conv branches concatenated on dim 1, then concatenated with ``x3 + x4`` on dim 0."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x1, x2, x3, x4):
+            """Return ``cat([x3 + x4, cat([conv1(x1), conv2(x2)], dim=1)])``."""
+            x1 = self.conv1(x1)
+            x2 = self.conv2(x2)
+            y = torch.cat([x1, x2], dim=1)
+            z = x3 + x4
+            # No dim given, so this second cat uses torch.cat's default dim (0).
+            w = torch.cat([z, y])
+            return w
+
+    class Conv2dWithSplit(torch.nn.Module):
+        """``Conv2d`` whose output is split along channels and concatenated back together."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            # conv2 is registered but not used in forward().
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            """Return ``cat(split(conv1(x), 2, dim=1), dim=1)``."""
+            x = self.conv1(x)
+            # use split so we get a list of Tensors
+            # With 3 output channels and split size 2 the chunks have 2 and 1 channels.
+            x1, x2 = torch.split(x, 2, dim=1)
+            y = torch.cat([x1, x2], dim=1)
+            return y
+
+        def example_inputs(self):
+            """Return a one-element tuple holding a random ``(1, 3, 16, 16)`` tensor."""
+            return (torch.randn(1, 3, 16, 16),)
+
+    class ThreeAdd(torch.nn.Module):
+        """Parameter-free module performing three additions over four inputs."""
+
+        def forward(self, x1, x2, x3, x4):
+            """Return ``(x1 + x2) + (x3 + x4)``."""
+            y = x1 + x2
+            z = x3 + x4
+            w = y + z
+            return w
+
+    class EmbeddingModule(torch.nn.Module):
+        """Wraps a single ``Embedding`` table with 10 entries of dimension 12."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+
+        def forward(self, indices):
+            """Return the embeddings looked up for ``indices``."""
+            return self.emb(indices)
+
+    class EmbeddingConvLinearModule(torch.nn.Module):
+        """``Embedding`` -> ``Conv2d`` -> ``Linear`` pipeline with permutes between the stages."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
+            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
+            self.linear = torch.nn.Linear(16, 8)
+
+        def forward(self, indices):
+            """Embed ``indices``, convolve over the embeddings and project with the linear layer."""
+            embeddings = self.emb(indices)
+            # Add a leading dim, then move the embedding dim (last) to position 1
+            # so it acts as the conv's 8 input channels.
+            embeddings = torch.unsqueeze(embeddings, dim=0)
+            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
+            conv_out = self.conv(embeddings)
+            # Move the 16 output channels back to the last dim and drop the leading dim.
+            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
+            conv_out = torch.squeeze(conv_out, dim=0)
+            return self.linear(conv_out)
+
+    class AddInplaceAdd(torch.nn.Module):
+        """Out-of-place add followed by an in-place add of the same operand."""
+
+        def forward(self, x, y):
+            """Return ``x + y`` with ``y`` then added again via ``+=``."""
+            x = x + y
+            x += y
+            return x
+
+    class MulInplaceMul(torch.nn.Module):
+        """Out-of-place multiply followed by an in-place multiply by the same operand."""
+
+        def forward(self, x, y):
+            """Return ``x * y`` with the result then multiplied by ``y`` again via ``*=``."""
+            x = x * y
+            x *= y
+            return x
+
+    class AddMulScalar(torch.nn.Module):
+        """Applies scalar add and multiply in both out-of-place and in-place forms."""
+
+        def forward(self, x):
+            """Return ``((x + 3) * 3 + 3) * 3``, with the last two steps done via ``+=`` and ``*=``."""
+            x = x + 3
+            x = x * 3
+            x += 3
+            x *= 3
+            return x
+
+    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
+        """``ConvWithBNRelu(relu=True)`` followed by a channels-last permute and a bias-free ``Linear(3, 8)``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
+            self.linear = torch.nn.Linear(3, 8, bias=False)
+            # NOTE: self.relu is registered but forward() never calls it.
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            """Return ``linear(permute(conv_bn_relu(x), (0, 2, 3, 1)))``."""
+            x = self.conv_bn_relu(x)
+            permute_out = torch.permute(x, (0, 2, 3, 1))
+            linear_out = self.linear(permute_out)
+            return linear_out
+
+    class GroupwiseConv2d(torch.nn.Module):
+        """Single grouped convolution: ``Conv2d(4, 4, 3, groups=2)``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)
+
+        def forward(self, x):
+            """Return ``conv(x)``."""
+            return self.conv(x)
+
+        def example_inputs(self):
+            """Return a one-element tuple holding a random ``(2, 4, 10, 10)`` tensor."""
+            return (torch.randn(2, 4, 10, 10),)
+
+    class LinearReluModel(torch.nn.Module):
+        """``Linear(5, 5)`` (cast to ``torch.float``) followed by ``ReLU``."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            """Return ``relu(fc(x))``."""
+            x = self.relu(self.fc(x))
+            return x
+
+
+def _generate_qdq_quantized_model(
+    mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
+):
+    def get_default_quantizer(is_qat, is_dynamic, inputs):
+        has_xpu = any(
+            isinstance(input, torch.Tensor) and input.device.type == "xpu"
+            for input in inputs
+        )
+        if has_xpu:
+            quantizer = XPUInductorQuantizer()
+            assert (not is_qat) and (
+                not is_dynamic
+            ), "QAT and dynamic quantization is not supported at XPU backend currently"
+            quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())
+        else:
+            quantizer = X86InductorQuantizer()
+            quantizer.set_global(
+                xiq.get_default_x86_inductor_quantization_config(
+                    is_qat=is_qat, is_dynamic=is_dynamic
+                )
+            )
+        return quantizer
+
+    maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
+    with maybe_no_grad:
+        export_model = export(mod, inputs, strict=True).module(check_guards=False)
+        quantizer = (
+            quantizer
+            if quantizer
+            else get_default_quantizer(is_qat, is_dynamic, inputs)
+        )
+        prepare_model = (
+            prepare_qat_pt2e(export_model, quantizer)
+            if is_qat
+            else prepare_pt2e(export_model, quantizer)
+        )
+        prepare_model(*inputs)
+        torch.ao.quantization.move_exported_model_to_eval(prepare_model)
+        convert_model = convert_pt2e(prepare_model)
+        return convert_model
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bd57fa976ebc671e0184cc1a32128a3aed5b6bf
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py
@@ -0,0 +1,675 @@
+# mypy: ignore-errors
+
+r"""Importing this file includes common utility methods for checking quantized
+tensors and modules.
+"""
+import numpy as np
+import torch
+from torch import Tensor
+from contextlib import contextmanager
+from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS
+
+# Quantized engines the tests in this module iterate over.
+supported_qengines = torch.backends.quantized.supported_engines
+# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
+# QNNPACK is not supported on PPC
+# QNNPACK is also dropped when TEST_WITH_TSAN is set.
+if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]):
+    supported_qengines.remove('qnnpack')
+
+def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
+                       output_padding=0):
+    """Computes the output shape given convolution parameters."""
+    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
+                     * (dilation - 1)) / stride) + 2 * output_padding + 1
+
+# Quantization references
+def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
+    """Quantizes a numpy array."""
+    if qmin is None:
+        qmin = np.iinfo(dtype).min
+    if qmax is None:
+        qmax = np.iinfo(dtype).max
+    qx = np.round(x / scale + zero_point).astype(np.int64)
+    qx = np.clip(qx, qmin, qmax)
+    qx = qx.astype(dtype)
+    return qx
+
+
+def _dequantize(qx, scale, zero_point):
+    """Dequantizes a numpy array."""
+    x = (qx.astype(float) - zero_point) * scale
+    return x
+
+
+def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
+    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
+    converted back to given type"""
+    qx = (x * multiplier).round() + zero_point
+    qx = np.clip(qx, qmin, qmax).astype(qtype)
+    return qx
+
+def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
+    if qscheme == torch.per_tensor_symmetric:
+        assert dtype == torch.qint8
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    if dtype == torch.qint8:
+        if reduce_range:
+            qmin, qmax = -64, 63
+        else:
+            qmin, qmax = -128, 127
+    else:  # dtype == torch.quint8
+        if reduce_range:
+            qmin, qmax = 0, 127
+        else:
+            qmin, qmax = 0, 255
+    min_val = X.min()
+    max_val = X.max()
+    is_symmetric = (qscheme == torch.per_tensor_symmetric)
+    if min_val == max_val:
+        scale = 1.0
+        zero_point = 0
+    else:
+        if is_symmetric:
+            max_val = max(max_val, -min_val)
+            min_val = -max_val
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = qmin - round(min_val / scale)
+            zero_point = max(qmin, zero_point)
+            zero_point = min(qmax, zero_point)
+    return [float(scale), int(zero_point)]
+
+def _calculate_dynamic_per_channel_qparams(X, dtype):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+    n_levels = qmax - qmin
+    scale = np.zeros(X.shape[0], dtype=np.float64)
+    zero_point = np.zeros(X.shape[0], dtype=np.int64)
+    for i in range(zero_point.shape[0]):
+        min_val = X.min()
+        max_val = X.max()
+        if min_val == max_val:
+            scale[i] = 1.0
+            zero_point[i] = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale[i] = (max_val - min_val) / n_levels
+            scale[i] = max(scale[i], np.finfo(np.float32).eps)
+            zero_point[i] = qmin - round(min_val / scale[i])
+            zero_point[i] = max(qmin, zero_point[i])
+            zero_point[i] = min(qmax, zero_point[i])
+
+    return scale, zero_point
+
+def _snr(x, x_hat):
+    """Calculates the signal to noise ratio and returns the signal and noise
+    power, as well as the SNR in dB.
+    If the input is a list/tuple this function is called recursively on each
+    element. The result will have the same nested structure as the inputs.
+
+    Args:
+        x, x_hat: Either a tensor or a nested list/tuple of tensors.
+    Returns:
+        signal, noise, SNR(in dB): Either floats or a nested list of floats
+    """
+    if isinstance(x, (list, tuple)):
+        assert len(x) == len(x_hat)
+        res = [_snr(x[idx], x_hat[idx]) for idx in range(len(x))]
+        return res
+    if x_hat.is_quantized:
+        x_hat = x_hat.dequantize()
+    if x.is_quantized:
+        x = x.dequantize()
+    noise = (x - x_hat).norm()
+    if noise == 0:
+        return 0.0, float('inf'), float('inf')
+    signal = x.norm()
+    snr = signal / noise
+    snr_db = 20 * snr.log10()
+    return signal, noise, snr_db
+
+@contextmanager
+def override_quantized_engine(qengine):
+    previous = torch.backends.quantized.engine
+    torch.backends.quantized.engine = qengine
+    try:
+        yield
+    finally:
+        torch.backends.quantized.engine = previous
+
+@contextmanager
+def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
+    """Context manager that switches to the default mobile CPU allocator for QNNPACK.
+
+    When ``qengine_is_qnnpack`` is falsy this is a no-op wrapper around the body.
+    """
+    try:
+        # The set call sits inside the try block, so the matching unset in
+        # ``finally`` runs even if the set call or the body raises.
+        if qengine_is_qnnpack:
+            torch._C._set_default_mobile_cpu_allocator()
+        yield
+    finally:
+        if qengine_is_qnnpack:
+            torch._C._unset_default_mobile_cpu_allocator()
+
+# TODO: Update all quantization tests to use this decorator.
+# Currently for some of the tests it seems to have inconsistent params
+# for fbgemm vs qnnpack.
+def override_qengines(qfunction):
+    """Decorator that runs ``qfunction`` once per entry in ``supported_qengines``.
+
+    Each call happens with that engine active via ``override_quantized_engine``.
+    The wrapper discards the return values of ``qfunction`` and returns ``None``.
+    """
+    def test_fn(*args, **kwargs):
+        for qengine in supported_qengines:
+            with override_quantized_engine(qengine):
+                # qfunction should not return anything.
+                qfunction(*args, **kwargs)
+    return test_fn
+
+def qengine_is_fbgemm():
+    return torch.backends.quantized.engine == 'fbgemm'
+def qengine_is_qnnpack():
+    return torch.backends.quantized.engine == 'qnnpack'
+def qengine_is_onednn():
+    return torch.backends.quantized.engine == 'onednn'
+def qengine_is_x86():
+    return torch.backends.quantized.engine == 'x86'
+
+# Helper function used to simulate per-channel fake-quant against any axis
+def _permute_to_axis_zero(X, axis):
+    new_axis_list = list(range(X.dim()))
+    new_axis_list[axis] = 0
+    new_axis_list[0] = axis
+    y = X.permute(tuple(new_axis_list))
+    return y, new_axis_list
+
+# Reference method for fake quantize
+# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
+def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
+    dtype = X.dtype
+    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
+    res = torch.zeros_like(X)
+
+    for i in range(X.size()[0]):
+        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
+                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
+
+    out = res.permute(tuple(permute_axis_list))
+    return out.to(dtype)
+
+# Reference method for the gradient of the fake quantize operator
+# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
+def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
+    """Reference gradient of per-channel fake quantization.
+
+    ``dY`` is passed through wherever the rounded, un-clamped quantized value
+    of ``X`` lies inside ``[quant_min, quant_max]`` and is zero elsewhere. The
+    result is returned in the dtype of ``X``.
+    """
+    orig_dtype = X.dtype
+    X, swap_order = _permute_to_axis_zero(X.to(torch.float32), axis)
+    rounded = torch.zeros_like(X)
+    for ch in range(X.size()[0]):
+        inv_scale = 1.0 / per_channel_scale[ch]
+        rounded[ch] = torch.round(X[ch] * inv_scale + per_channel_zero_point[ch])
+    # back to the caller's dimension order so the mask lines up with dY
+    rounded = rounded.permute(tuple(swap_order))
+    in_range = (rounded >= quant_min) * (rounded <= quant_max)
+    grad = torch.zeros_like(dY)
+    grad[in_range] = dY[in_range]
+    return grad.to(orig_dtype)
+
+def to_tensor(X, device):
+    """Return ``X`` as a float32 tensor on ``device``.
+
+    Tensor inputs are detached and cloned first; anything else is passed to
+    ``torch.tensor``.
+    """
+    if isinstance(X, torch.Tensor):
+        data = X.detach().clone()
+    else:
+        data = torch.tensor(X)
+    target_device = torch.device(device)
+    return data.to(device=target_device, dtype=torch.float32)
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
+def _n_ones(n: int) -> int:
+    """Return an integer whose ``n`` least-significant bits are all set."""
+    bit_above = 1 << n
+    return bit_above - 1
+
+# Number of exponent and mantissa bits used for the float32 side of the
+# conversions below.
+EBITS_F32, MBITS_F32 = 8, 23
+# Float32 exponent bias derived from the exponent width: _n_ones(8 - 1) == 127.
+F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
+def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+
+    Note: there are no special values (NaN, inf) support in this code. Values
+    outside the representable range of Floatx after rounding are clamped to the
+    maximum Floatx magnitude (sign is preserved).
+
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+
+    Background 1: last answer in https://stackoverflow.com/q/8981913
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+
+    Args:
+        x: tensor of dtype torch.float (asserted below).
+        ebits: number of exponent bits of the target format.
+        mbits: number of mantissa bits of the target format; together with the
+            sign bit the format must fit in 8 bits (asserted below).
+
+    Returns:
+        torch.uint8 tensor holding one encoded value per input element.
+    """
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    # largest encoded magnitude: all exponent and mantissa bits set
+    max_int = _n_ones(ebits + mbits)
+    # position of the sign bit in the low-precision encoding
+    sign_mask = 1 << (ebits + mbits)
+
+    # TODO document this better
+    # NOTE(review): this sets the low (MBITS_F32 - mbits - 1) bits, i.e. one less
+    # than half of the lowest kept mantissa bit; together with `mant_odd` below it
+    # appears to implement round-to-nearest-even before the low bits are shifted
+    # out — confirm.
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    # denorm_exp placed in the FP32 exponent field (mantissa bits all zero)
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    # (`^` allocates a new tensor, so the in-place ops further down do not
+    # touch the caller's input)
+    x = x ^ sign
+
+    # TODO: can the branch floating point comparisons below be done without
+    # converting to float? probably but need to verify
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    #
+    # branch 1: saturate to max val - handled later in the code which combines
+    #   the branches
+    #
+
+    #
+    # branch 2: to conversion to denormal as well as rounding up to normal
+    #
+    # NOTE(review): the float addition presumably lets the FP32 adder align and
+    # round the mantissa; subtracting the constant's bit pattern afterwards
+    # leaves the low-precision encoding in the low bits — confirm.
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    #
+    # branch 3: stay in normal range, adjust the exponent and round
+    #
+    # normal_x is a view of x, so the in-place adds below also change x's
+    # storage; x is only used for its shape (full_like) after this point.
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    #
+    # combine the branches
+    #
+    # start from the saturated value everywhere, then overwrite the elements
+    # that fall in the denormal and normal ranges
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/29488018d99af7f7339f06353c6b5bbeae8a1493/torchao/prototype/custom_fp_utils.py#L147
+def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
+
+    Args:
+        x: torch.uint8 tensor of encoded values (asserted below).
+        ebits: number of exponent bits of the source format.
+        mbits: number of mantissa bits of the source format; together with the
+            sign bit the format must fit in 8 bits (asserted below).
+    """
+    assert x.dtype == torch.uint8
+    assert 1 + ebits + mbits <= 8
+
+    sign_mask = 1 << (ebits + mbits)
+    exp_bias = _n_ones(ebits - 1)
+    mantissa_mask = _n_ones(mbits)
+
+    # save the sign
+    sign_lp = x & sign_mask
+
+    # set everything to positive, will add sign back at the end
+    x_pos = x ^ sign_lp
+
+    #
+    # 1. Calculate zero mask
+    #
+    zero_mask = x_pos == 0
+
+    #
+    # 2. Calculate the denormal path mask
+    #
+    # nonzero value whose exponent field (bits above the mantissa) is zero
+    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
+
+    #
+    # 3. Calculate the normal path
+    #
+
+    # calculate the new exponent and shift it to bits 2:9 of the result
+    exp_biased_lp = x_pos >> mbits
+    # NOTE(review): this re-biasing is done on the uint8 tensor before the cast
+    # to int32 on the next line, so intermediates may wrap modulo 256 — confirm
+    # this is intended.
+    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
+    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
+
+    # shift the mantissa to bits 10:32 of the result
+    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
+    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
+    result = exp_biased_f32 | mantissa_f32
+
+    #
+    # 4. Add the zero and denormal casts to the already casted normal path
+    #
+    # in-place: elements that encode zero get an all-zero FP32 bit pattern
+    result[zero_mask] = 0
+
+    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
+
+    # fast path.
+    # without this, performance for FP4_E2M1 is slower by 2x
+    if mbits == 1:
+        # with one mantissa bit there is a single nonzero denormal encoding, so
+        # every denormal element receives the same FP32 bit pattern
+        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
+
+    else:
+        # iterate over all possible values of mantissa
+        # i=0, j=1
+        # i=1, j=10,11
+        # i=2, j=100,101,110,111
+        # and so on
+        for i in range(mbits):
+            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
+                # left shift mantissa until it overflows (create an implicit 1)
+                # subtract exponent by the same amount
+                left_shift = mbits - i
+                # these two names are rebound to plain ints here; the tensors
+                # they held earlier were already folded into `result`
+                mantissa_f32 = (mantissa_cmp - (1 << i)) << (
+                    left_shift + MBITS_F32 - mbits
+                )
+                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
+
+                # we can update this in-place since the values won't overlap
+                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
+                # thus we use + instead of | here
+                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = (
+                    exp_biased_f32 + mantissa_f32
+                )
+
+        # mantissa_lp_int32 now holds FP32 bit patterns for the denormal
+        # encodings; take them only where the element is denormal
+        result = torch.where(denormal_mask, mantissa_lp_int32, result)
+
+    # add sign back
+    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
+    result = result | sign_f32
+
+    # reinterpret the assembled int32 bit patterns as float32
+    return result.view(torch.float)
+
+# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
+def ceil_div(a, b):
+    """Return ``a`` divided by ``b`` rounded up, computed as ``(a + b - 1) // b``."""
+    biased_numerator = a + b - 1
+    return biased_numerator // b
+
+# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
+# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
+# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
+# more naturally.
+def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
+    """De-swizzle blocked scales into a plain 2-D scale matrix.
+
+    ``input_scales`` is indexed as a flat buffer laid out in 128x4 tiles (512
+    values per tile) with the 32x4x4 ordering inside each tile. The result has
+    shape ``[input.size(0), input.size(1) // blocksize]``, the dtype of
+    ``input_scales`` and the device of ``input``.
+
+    This is a (very) slow reference implementation that copies one scale at a
+    time.
+    """
+    n_rows = input.size(0)
+    n_scale_cols = input.size(1) // blocksize
+    output_scales = torch.zeros(
+        (n_rows, n_scale_cols),
+        device=input.device,
+        dtype=input_scales.dtype,
+    )
+
+    # The swizzled buffer is padded to whole 128x4 tiles. There are
+    # K // blocksize scales per row, padded to groups of 4, which gives the
+    # number of tiles per tile-row used in the offset below.
+    num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)
+
+    for i in range(n_rows):
+        # tile row and row within the tile ("outer" in the cublas docs)
+        scale_tile_h, outer = divmod(i, 128)
+        for j in range(n_scale_cols):
+            # tile column and column within the tile ("inner" in the cublas docs)
+            scale_tile_w, inner = divmod(j, 4)
+
+            # each tile holds 512 values
+            tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)
+
+            # The cublas docs (3.1.4.3.2) give this offset in bytes; the scales
+            # are one byte each, so it is used directly as an element index.
+            offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner
+
+            output_scales[i, j] = input_scales[offset]
+
+    return output_scales
+
+def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
+    """Apply per-block scales to ``x_mxfp8`` and return the result as bfloat16.
+
+    Each column of ``scales_unswizzled`` is repeated ``blocksize`` times along
+    dim 1 so that it lines up element-wise with ``x_mxfp8``. The product is
+    computed in float32.
+    """
+    per_element_scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)
+    values_f32 = x_mxfp8.to(torch.float)
+    scales_f32 = per_element_scales.to(torch.float)
+    return (values_f32 * scales_f32).to(torch.bfloat16)
+
+def to_blocked(input_matrix) -> torch.Tensor:
+    """
+    Rearrange a 2-D matrix into the blocked (swizzled) scaling-factor layout.
+
+    The matrix is zero-padded to a multiple of 128 rows and 4 columns, split
+    into 128x4 tiles, and each tile is reordered into a 32x16 arrangement.
+
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Flattened 1-D tensor with 512 elements per 128x4 tile, i.e.
+        512 * ceil_div(H, 128) * ceil_div(W, 4) elements in total.
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks, n_col_blocks = ceil_div(rows, 128), ceil_div(cols, 4)
+    padded_shape = (n_row_blocks * 128, n_col_blocks * 4)
+
+    if (rows, cols) == padded_shape:
+        padded = input_matrix
+    else:
+        # Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now
+        padded = torch.zeros(padded_shape, device=input_matrix.device, dtype=input_matrix.dtype)
+        padded[:rows, :cols] = input_matrix
+
+    # One 128x4 tile per (row block, col block) pair
+    tiles = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    # Reorder the 128 rows of each tile as 4 groups of 32 and interleave the groups
+    swizzled = tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    return swizzled.flatten()
+
+
+def down_size(size):
+    """Return ``size`` as a tuple with its last dimension halved.
+
+    The last dimension must be even.
+    """
+    last_dim = size[-1]
+    assert last_dim % 2 == 0, f"{size} last dim not divisible by two"
+    leading_dims = size[:-1]
+    return (*leading_dims, last_dim // 2)
+
+
+def pack_uint4(uint8_data) -> torch.Tensor:
+    """Pack consecutive pairs of values into single elements.
+
+    In the flattened input, even-indexed elements fill the low 4 bits and
+    odd-indexed elements are shifted into the high 4 bits. The result has the
+    input's shape with the last dimension halved.
+    """
+    shape = uint8_data.shape
+    assert shape[-1] % 2 == 0
+    flat = uint8_data.contiguous().view(-1)
+    low_nibbles = flat[::2]
+    high_nibbles = flat[1::2] << 4
+    return (high_nibbles | low_nibbles).view(down_size(shape))
+
+
+# exponent and mantissa bits of `torch.float4_e2m1fn_x2`
+# (passed as `ebits` / `mbits` to `_f32_to_floatx_unpacked` when converting
+# bfloat16 data to packed fp4)
+FP4_EBITS, FP4_MBITS = 2, 1
+
+
+def _bfloat16_to_float4_e2m1fn_x2(x):
+    """Convert a bfloat16 tensor to packed ``torch.float4_e2m1fn_x2``.
+
+    The values are encoded as unpacked fp4 (one value per uint8), packed two
+    per byte, and the bytes are reinterpreted as ``torch.float4_e2m1fn_x2``.
+    """
+    assert x.dtype == torch.bfloat16
+    unpacked_fp4 = _f32_to_floatx_unpacked(x.float(), FP4_EBITS, FP4_MBITS)
+    packed_bytes = pack_uint4(unpacked_fp4)
+    return packed_bytes.view(torch.float4_e2m1fn_x2)
+
+
+# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
+def to_mxfp(
+    data_hp: torch.Tensor,
+    block_size: int = 32,
+    format: str = "mxfp8",
+):
+    """Quantize a high-precision tensor to an MX block-scaled format.
+
+    The last dimension is split into blocks of ``block_size`` elements. Each
+    block shares one power-of-two scale, stored as a biased e8m0 exponent and
+    derived from the block's maximum absolute value with ceil rounding
+    ("RCEIL").
+
+    Args:
+        data_hp: contiguous ``torch.bfloat16`` or ``torch.float`` tensor whose
+            last dimension is divisible by ``block_size``.
+        block_size: number of consecutive elements that share one scale.
+        format: ``"mxfp8"`` (elements cast to ``torch.float8_e4m3fn``) or
+            ``"mxfp4"`` (elements packed two per byte as
+            ``torch.float4_e2m1fn_x2``).
+
+    Returns:
+        ``(scale_e8m0_biased, data_lp)``. The scales are viewed as
+        ``torch.float8_e8m0fnu`` with one entry per block. The data keeps the
+        original shape, with the last dimension halved for ``"mxfp4"``.
+
+    Raises:
+        AssertionError: if the dtype, divisibility or contiguity requirements
+            on ``data_hp`` are not met.
+        ValueError: if ``format`` is not a supported format name.
+    """
+    assert data_hp.dtype in (
+        torch.bfloat16,
+        torch.float,
+    ), f"{data_hp.dtype} is not supported yet"
+    assert (
+        data_hp.shape[-1] % block_size == 0
+    ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    assert data_hp.is_contiguous(), "unsupported"
+
+    orig_shape = data_hp.shape
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )
+
+    # per-block maximum magnitude, kept as a trailing size-1 dim for broadcasting
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    data_hp = data_hp.to(torch.float32)
+    max_abs = max_abs.to(torch.float32)
+
+    if format == "mxfp8":
+        F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
+        max_pos = F8E4M3_MAX
+    elif format == "mxfp4":
+        F4E2M1_MAX = 6.
+        max_pos = F4E2M1_MAX
+    else:
+        # Previously an unknown format left `max_pos` unbound and failed later
+        # with an UnboundLocalError; fail here with an explicit message instead.
+        raise ValueError(
+            f"unsupported format {format!r}, expected 'mxfp8' or 'mxfp4'"
+        )
+
+    # RCEIL
+    def _to_mx_rceil(
+        data_hp: torch.Tensor,
+        max_abs: torch.Tensor,
+        max_pos: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        E8M0_EXPONENT_BIAS = 127
+        descale = max_abs / max_pos
+        exponent = torch.where(
+            torch.isnan(descale),
+            0xFF,  # Handle biased exponent for nan
+            # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
+            (
+                torch.clamp(
+                    torch.ceil(torch.log2(descale)),
+                    min=-E8M0_EXPONENT_BIAS,
+                    max=E8M0_EXPONENT_BIAS,
+                )
+                + E8M0_EXPONENT_BIAS
+            ).to(torch.uint8),
+        )
+
+        # a biased exponent of 0 leaves the data unscaled
+        descale_fp = torch.where(
+            exponent == 0,
+            1.0,
+            torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
+        )
+
+        # scale and saturated cast the data elements to max of target dtype
+        data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
+        return exponent, data_lp
+
+    scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
+
+    # cast to target dtype
+    if format == "mxfp8":
+        data_lp = data_lp.to(torch.float8_e4m3fn)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
+    elif format == "mxfp4":
+        data_lp = _bfloat16_to_float4_e2m1fn_x2(data_lp.to(torch.bfloat16))
+        # two fp4 values are packed per element, so the last dim is halved
+        final_shape = list(orig_shape)
+        final_shape[-1] //= 2
+        data_lp = data_lp.reshape(final_shape)
+
+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
+    return scale_e8m0_biased, data_lp
+
+# Source: https://github.com/pytorch/ao/blob/568c1932a16ae9f30d48da214a88dc0013e98ed8/torchao/prototype/moe_training/utils.py#L310
+def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
+    """
+    Utility function for tests and benchmarks.
+
+    Generates a tensor of length E holding distinct random multiples of
+    `multiple_of` in the range [multiple_of, M], in sorted order, where the
+    final value in the tensor is always M.
+    Args:
+        E (int): The length of the tensor.
+        M (int): The maximum value in the tensor; must be divisible by `multiple_of`.
+        multiple_of (int): Every generated value is a multiple of this number.
+        dtype: dtype of the returned tensor.
+        device: device of the returned tensor.
+    Returns:
+        torch.Tensor: A tensor of length E with the specified properties.
+    Raises:
+        ValueError: If M is not divisible by `multiple_of`, or if E exceeds the
+            number of available multiples.
+    """
+    import random
+
+    # M itself has to be a valid offset
+    if M % multiple_of != 0:
+        raise ValueError(f"M must be divisible by {multiple_of}")
+
+    # Every admissible offset, ending with M
+    candidates = list(range(multiple_of, M + 1, multiple_of))
+
+    if E > len(candidates):
+        raise ValueError("E cannot be larger than the number of possible values")
+
+    # Draw the E - 1 interior offsets without replacement; M is added separately
+    interior = torch.tensor(random.sample(candidates[:-1], E - 1))
+    offsets = torch.cat((interior, torch.tensor([M])))
+    offsets = torch.sort(offsets).values
+
+    return offsets.to(dtype).to(device)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py
new file mode 100644
index 0000000000000000000000000000000000000000..cca291133d3e945c6b42054577a711d781857cac
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py
@@ -0,0 +1,343 @@
+# mypy: ignore-errors
+
+import torch
+from copy import deepcopy
+from torch.utils._pytree import tree_map
+import torch.utils._pytree as pytree
+
+
+# TODO: Move LoggingTensor here.
+from torch.testing._internal.logging_tensor import LoggingTensor
+
+
+# Base class for wrapper-style tensors.
+class WrapperTensor(torch.Tensor):
+    """Base class for wrapper-style tensor subclasses.
+
+    Subclasses implement ``get_wrapper_properties`` to supply an example tensor
+    plus overrides for the wrapper's metadata.
+    """
+
+    @staticmethod
+    def __new__(cls, *args, **kwargs):
+        example, props = cls.get_wrapper_properties(*args, **kwargs)
+        # An explicit "size" override is consumed here because the size is
+        # passed positionally to _make_wrapper_subclass.
+        size = props.pop("size") if "size" in props else example.size()
+        # Anything not overridden is taken from the example tensor.
+        props.setdefault("dtype", example.dtype)
+        props.setdefault("layout", example.layout)
+        props.setdefault("device", example.device)
+        props.setdefault("requires_grad", False)
+        # memory_format and pin_memory are ignored for now: it is unclear how to
+        # safely read them from a Tensor.
+
+        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **props)
+        wrapper._validate_methods()
+        return wrapper
+
+    @classmethod
+    def get_wrapper_properties(cls, *args, **kwargs):
+        # Must return an example Tensor and a dictionary of kwargs that
+        # override any of that example Tensor's properties.
+        # This is very similar to the `t.new_*(args)` API
+        raise NotImplementedError("You need to implement get_wrapper_properties")
+
+    def _validate_methods(self):
+        # Skip this if not in debug mode?
+        # Overriding these on the python side is wrong because the change would
+        # not be reflected on the c++ side.
+        # This doesn't catch attributes set in the __init__
+        subclass = self.__class__
+        for attr in ("size", "stride", "dtype", "layout", "device", "requires_grad"):
+            if getattr(subclass, attr) is getattr(torch.Tensor, attr):
+                continue
+            raise RuntimeError(f"Subclass {subclass.__name__} is overwriting the "
+                               f"property {attr} but this is not allowed as such change would "
+                               "not be reflected to c++ callers.")
+
+
+class WrapperTensorWithCustomSizes(WrapperTensor):
+    """Wrapper around a tensor ``t``, created with the "sizes" dispatch policy.
+
+    Dispatched ops run on the wrapped tensors and tensor results are wrapped
+    again.
+    """
+
+    @classmethod
+    def get_wrapper_properties(cls, t, requires_grad=False):
+        overrides = {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "sizes"}
+        return t, overrides
+
+    def __init__(self, t, requires_grad=False):
+        # the wrapped tensor that ops are forwarded to
+        self.t = t
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        # Only handle calls where every participating type is a base of cls
+        for other in types:
+            if not issubclass(cls, other):
+                return NotImplemented
+
+        def to_inner(obj):
+            if isinstance(obj, WrapperTensorWithCustomSizes):
+                return obj.t
+            return obj
+
+        def to_wrapper(obj):
+            if isinstance(obj, torch.Tensor):
+                return WrapperTensorWithCustomSizes(obj)
+            return obj
+
+        inner_args = tree_map(to_inner, args)
+        inner_kwargs = tree_map(to_inner, kwargs or {})
+        return tree_map(to_wrapper, func(*inner_args, **inner_kwargs))
+
+    def __repr__(self):
+        contents = f"t={self.t}"
+        return super().__repr__(tensor_contents=contents)
+
+
+class WrapperTensorWithCustomStrides(WrapperTensor):
+    """Wrapper around a tensor ``t``, created with the "strides" dispatch policy.
+
+    Dispatched ops run on the wrapped tensors and tensor results are wrapped
+    again.
+    """
+
+    @classmethod
+    def get_wrapper_properties(cls, t, requires_grad=False):
+        overrides = {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "strides"}
+        return t, overrides
+
+    def __init__(self, t, requires_grad=False):
+        # the wrapped tensor that ops are forwarded to
+        self.t = t
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        # Only handle calls where every participating type is a base of cls
+        for other in types:
+            if not issubclass(cls, other):
+                return NotImplemented
+
+        def to_inner(obj):
+            if isinstance(obj, WrapperTensorWithCustomStrides):
+                return obj.t
+            return obj
+
+        def to_wrapper(obj):
+            if isinstance(obj, torch.Tensor):
+                return WrapperTensorWithCustomStrides(obj)
+            return obj
+
+        inner_args = tree_map(to_inner, args)
+        inner_kwargs = tree_map(to_inner, kwargs or {})
+        return tree_map(to_wrapper, func(*inner_args, **inner_kwargs))
+
+    def __repr__(self):
+        contents = f"t={self.t}"
+        return super().__repr__(tensor_contents=contents)
+
+
+class DiagTensorBelow(WrapperTensor):
+    """Square-matrix wrapper that stores only the 1-D diagonal ``diag``."""
+
+    @classmethod
+    def get_wrapper_properties(cls, diag, requires_grad=False):
+        assert diag.ndim == 1
+        # A length-n diagonal is exposed as an (n, n) tensor
+        square_size = diag.size() + diag.size()
+        return diag, {"size": square_size, "requires_grad": requires_grad}
+
+    def __init__(self, diag, requires_grad=False):
+        self.diag = diag
+
+    # Maps an op's __name__ to a handler that is called instead of the fallback
+    handled_ops = {}
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        for other in types:
+            if not issubclass(cls, other):
+                return NotImplemented
+
+        # A registered handler takes precedence over the fallback
+        handler = cls.handled_ops.get(func.__name__, None)
+        if handler:
+            return handler(*args, **(kwargs or {}))
+
+        # No autograd formula is needed here, so the default fallback builds
+        # plain dense Tensors from the diag elements and calls func again.
+
+        def to_dense(obj):
+            if isinstance(obj, DiagTensorBelow):
+                return obj.diag.diag()
+            return obj
+
+        def rewrap(obj):
+            if not isinstance(obj, torch.Tensor):
+                return obj
+            if obj.ndim == 1:
+                return DiagTensorBelow(obj)
+            # 2-D results are re-wrapped only when every nonzero is on the diagonal
+            if obj.ndim == 2 and obj.count_nonzero() == obj.diag().count_nonzero():
+                return DiagTensorBelow(obj.diag())
+            return obj
+
+        dense_args = tree_map(to_dense, args)
+        dense_kwargs = tree_map(to_dense, kwargs or {})
+        return tree_map(rewrap, func(*dense_args, **dense_kwargs))
+
+    def __repr__(self):
+        contents = f"diag={self.diag}"
+        return super().__repr__(tensor_contents=contents)
+
+
+class SparseTensor(WrapperTensor):
+    """Toy sparse tensor stored as nonzero ``values`` and their ``indices``.
+
+    Ops without a special implementation are computed densely and the results
+    converted back.
+    """
+
+    @classmethod
+    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
+        assert values.device == indices.device
+        return values, {"size": size, "requires_grad": requires_grad}
+
+    def __init__(self, size, values, indices, requires_grad=False):
+        self.values = values
+        self.indices = indices
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")
+
+    def sparse_to_dense(self):
+        """Materialize a dense tensor with ``values`` scattered at ``indices``."""
+        # Allocate on the device of the stored values: a default-device buffer
+        # cannot be indexed with indices that live on another device.
+        res = torch.zeros(self.size(), dtype=self.values.dtype, device=self.values.device)
+        # each row of `indices` is one coordinate, hence one index tensor per dim
+        res[self.indices.unbind(1)] = self.values
+        return res
+
+    @staticmethod
+    def from_dense(t):
+        """Build a SparseTensor from the nonzero entries of the dense tensor ``t``."""
+        indices = t.nonzero()
+        values = t[indices.unbind(1)]
+        return SparseTensor(t.size(), values, indices)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        func_name = f"{func.__module__}.{func.__name__}"
+
+        res = cls._try_call_special_impl(func_name, args, kwargs)
+        if res is not NotImplemented:
+            return res
+
+        # Otherwise, use a default implementation that construct dense
+        # tensors and use that to compute values
+        def unwrap(e):
+            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e
+
+        # Wrap back all Tensors into our custom class
+        def wrap(e):
+            # Check for zeros and use that to get indices
+            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+    # Maps "<module>.<name>" of an op to a callable taking (args, kwargs)
+    _SPECIAL_IMPLS = {}
+
+    @classmethod
+    def _try_call_special_impl(cls, func, args, kwargs):
+        """Run the special implementation registered under ``func``, else return NotImplemented."""
+        if func not in cls._SPECIAL_IMPLS:
+            return NotImplemented
+        return cls._SPECIAL_IMPLS[func](args, kwargs)
+
+
+# Example non-wrapper subclass that stores extra state.
+class NonWrapperTensor(torch.Tensor):
+    """Example non-wrapper subclass that carries an ``extra_state`` dict."""
+
+    def __new__(cls, data):
+        instance = torch.Tensor._make_subclass(cls, data)
+        instance.extra_state = {'last_func_called': None}
+        return instance
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        result = super().__torch_function__(func, types, args, kwargs)
+
+        # Only results of this subclass carry extra state
+        if not isinstance(result, cls):
+            return result
+
+        if func is torch.Tensor.__deepcopy__:
+            # a deep copy keeps (a copy of) the source's state
+            result.extra_state = deepcopy(args[0].extra_state)
+        else:
+            # for the example, just record the name of the last function called
+            result.extra_state = {'last_func_called': func.__name__}
+        return result
+
+    # new_empty() must be defined for deepcopy to work
+    def new_empty(self, shape):
+        subclass = type(self)
+        return subclass(torch.empty(shape))
+
+
+# Class used to store info about subclass tensors used in testing.
+class SubclassInfo:
+    """Record describing one tensor subclass used in testing.
+
+    Attributes are ``name``, ``create_fn`` (called as ``create_fn(shape)`` and
+    expected to return a tensor instance) and ``closed_under_ops``.
+    """
+
+    __slots__ = ['name', 'create_fn', 'closed_under_ops']
+
+    def __init__(self, name, create_fn, closed_under_ops=True):
+        self.closed_under_ops = closed_under_ops
+        self.create_fn = create_fn
+        self.name = name
+
+
+# Helper function to create a subclass of the given class and possibly cache sizes / strides.
+def _create_and_access_shape(cls, shape):
+    """Build ``cls(torch.randn(shape))`` and call ``size()`` and ``stride()`` on it once."""
+    instance = cls(torch.randn(shape))
+    # NB: Wrapper subclasses with custom dispatched sizes / strides cache this
+    # info on first access via non-serializable PyCapsules. The cache is
+    # populated on purpose so that serialization / deepcopy tests verify it
+    # causes no problems.
+    instance.size()
+    instance.stride()
+    return instance
+
+
+# Registry mapping each tensor class used in testing to its SubclassInfo.
+# Each create_fn takes a shape and returns an instance of the class built from
+# random data.
+subclass_db = {
+    torch.Tensor: SubclassInfo(
+        'base_tensor', create_fn=torch.randn
+    ),
+    NonWrapperTensor: SubclassInfo(
+        'non_wrapper_tensor',
+        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
+    ),
+    LoggingTensor: SubclassInfo(
+        'logging_tensor',
+        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
+    ),
+    SparseTensor: SubclassInfo(
+        'sparse_tensor',
+        # relu() zeroes the negative entries, so the dense input has zeros to drop
+        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
+    ),
+    DiagTensorBelow: SubclassInfo(
+        'diag_tensor_below',
+        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
+        closed_under_ops=False  # sparse semantics
+    ),
+    WrapperTensorWithCustomSizes: SubclassInfo(
+        'wrapper_with_custom_sizes',
+        # size() / stride() are called once at creation to populate the cache
+        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomSizes, shape),
+        closed_under_ops=False,
+    ),
+    WrapperTensorWithCustomStrides: SubclassInfo(
+        'wrapper_with_custom_strides',
+        # size() / stride() are called once at creation to populate the cache
+        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomStrides, shape),
+        closed_under_ops=False,
+    ),
+}
+
+class SubclassWithTensorFactory(torch.Tensor):
+    """Wrapper subclass whose dispatch calls a tensor factory (``torch.ones``).
+
+    The wrapper mirrors the metadata of ``src``. During dispatch, float32
+    inputs are multiplied by a freshly created tensor of ones before the op
+    runs on the unwrapped values.
+    """
+
+    @staticmethod
+    def __new__(cls, src):
+        meta = {
+            "strides": src.stride(),
+            "storage_offset": src.storage_offset(),
+            "device": src.device,
+            "layout": src.layout,
+            "requires_grad": src.requires_grad,
+            "dtype": src.dtype,
+        }
+        return torch.Tensor._make_wrapper_subclass(cls, src.shape, **meta)
+
+    def __init__(self, src):
+        self.src = src
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}"
+
+    def __tensor_flatten__(self):
+        # names of the inner tensor attributes, and no extra metadata
+        return ["src"], None
+
+    @classmethod
+    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
+        return cls(inner_tensors["src"])
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        def _unwrap(x):
+            if x.src.dtype == torch.float32:
+                # the tensor factory call this class exists to exercise
+                return x.src * torch.ones(x.src.shape)
+            return x.src
+
+        plain_args = pytree.tree_map_only(cls, _unwrap, args)
+        plain_kwargs = pytree.tree_map_only(cls, _unwrap, kwargs)
+
+        raw_out = func(*plain_args, **plain_kwargs)
+
+        # Re-wrap every tensor leaf of the output, leave other leaves untouched
+        leaves, spec = pytree.tree_flatten(raw_out)
+        rewrapped = [cls(leaf) if isinstance(leaf, torch.Tensor) else leaf for leaf in leaves]
+        return pytree.tree_unflatten(rewrapped, spec)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a0b3c3a537116daa3be625a7dea1d6f60acd647
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py
@@ -0,0 +1,5882 @@
+# mypy: allow-untyped-defs
+
+r"""Importing this file must **not** initialize CUDA context. test_distributed
+relies on this assumption to properly run. This means that when this is imported
+no CUDA calls shall be made, including torch.cuda.device_count(), etc.
+
+torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported.
+"""
+
+import argparse
+import contextlib
+import copy
+import ctypes
+import errno
+import functools
+import gc
+import hashlib
+import inspect
+import io
+import json
+import logging
+import math
+import operator
+import os
+import pathlib
+import platform
+import random
+import re
+import shutil
+import signal
+import socket
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import types
+import unittest
+import warnings
+from collections.abc import Mapping, Sequence
+from contextlib import closing, contextmanager
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial, wraps
+from itertools import product, chain
+from pathlib import Path
+from statistics import mean
+from typing import (
+    Any,
+    Optional,
+    TypeVar,
+    Union,
+)
+from collections.abc import Callable
+from collections.abc import Iterable, Iterator
+from unittest.mock import MagicMock
+
+import expecttest
+import numpy as np
+
+import __main__  # type: ignore[import]
+import torch
+import torch.backends.cudnn
+import torch.backends.mkl
+import torch.backends.mps
+import torch.backends.xnnpack
+import torch.cuda
+from torch import Tensor
+from torch._C import ScriptDict, ScriptList  # type: ignore[attr-defined]
+from torch._utils_internal import get_writable_path
+from torch._logging.scribe import open_source_signpost
+from torch.nn import (
+    ModuleDict,
+    ModuleList,
+    ParameterDict,
+    ParameterList,
+    Sequential,
+)
+from torch.onnx import (
+    register_custom_op_symbolic,
+    unregister_custom_op_symbolic,
+)
+from torch.testing import make_tensor
+from torch.testing._comparison import (
+    BooleanPair,
+    NonePair,
+    NumberPair,
+    Pair,
+    TensorLikePair,
+)
+from torch.testing._comparison import not_close_error_metas
+from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.utils._import_utils import _check_module_exists
+import torch.utils._pytree as pytree
+from torch.utils import cpp_extension
+try:
+    import pytest  # type: ignore[import-not-found]
+    has_pytest = True
+except ImportError:
+    has_pytest = False
+
+# Fallback RNG seed used by this module.
+SEED = 1234
+# Tuples of gfx architecture names. Membership is tested with `in` against the
+# prefix of `gcnArchName` (see is_navi3_arch), so each constant must be a tuple:
+# a bare string would turn the check into a substring match.
+MI350_ARCH = ("gfx950",)
+MI300_ARCH = ("gfx942",)
+MI200_ARCH = ("gfx90a",)
+NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")
+
+class ProfilingMode(Enum):
+    """Value type of the module-level GRAPH_EXECUTOR setting."""
+    # NOTE(review): the member names suggest graph-executor profiling modes, but their
+    # exact meaning is defined by the consumers of GRAPH_EXECUTOR - confirm there.
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
+# Set by parse_cmd_line_args() if called
+# The values below are the module-level defaults that remain in effect otherwise.
+DISABLED_TESTS_FILE = ""
+GRAPH_EXECUTOR : Optional[ProfilingMode] = None
+LOG_SUFFIX = ""
+PYTEST_SINGLE_TEST = ""
+REPEAT_COUNT = 0
+RERUN_DISABLED_TESTS = False
+RUN_PARALLEL = 0
+SHOWLOCALS = False
+SLOW_TESTS_FILE = ""
+TEST_BAILOUTS = False
+TEST_DISCOVER = False
+TEST_IN_SUBPROCESS = False
+TEST_SAVE_XML = ""
+UNITTEST_ARGS : list[str] = []
+USE_PYTEST = False
+
+def is_navi3_arch():
+    """
+    Returns True when CUDA is available and the gfx arch name reported for device 0
+    (the part of ``gcnArchName`` before the first ":") is listed in NAVI3_ARCH.
+    Returns False otherwise.
+    """
+    if not torch.cuda.is_available():
+        return False
+    arch_name = torch.cuda.get_device_properties(0).gcnArchName
+    return arch_name.split(":")[0] in NAVI3_ARCH
+
+def freeze_rng_state(*args, **kwargs):
+    """Forwards all arguments to torch.testing._utils.freeze_rng_state and returns its result."""
+    # Looked up at call time, exactly like the attribute chain it replaces.
+    impl = torch.testing._utils.freeze_rng_state
+    return impl(*args, **kwargs)
+
+
+# Class to keep track of test flags configurable by environment variables.
+# Flags set here are intended to be read-only and should not be modified after
+# definition.
+# TODO: Expand this class to handle arbitrary settings in addition to boolean flags?
+class TestEnvironment:
+    # Set of env vars to set for the repro command that is output on test failure.
+    # Specifically, this includes env vars that are set to non-default values and
+    # are not implied. Maps from env var name -> value (the raw env var string for
+    # flags, the parsed value for settings).
+    repro_env_vars: dict = {}
+
+    # Defines a flag usable throughout the test suite, determining its value by querying
+    # the specified environment variable.
+    #
+    # Args:
+    #     name (str): The name of the flag. A global variable with this name will be set
+    #         for convenient access throughout the test suite.
+    #     env_var (str): The name of the primary environment variable from which to
+    #         determine the value of this flag. If this is None, the default value will be
+    #         used unless otherwise implied (see implied_by_fn). If the environment variable
+    #         is unset, enabled_fn is called with None as the env var value. Default: None
+    #     default (bool): The default value to use for the flag if unset by the environment
+    #         variable and unimplied. Default: False
+    #     include_in_repro (bool): Indicates whether this flag should be included in the
+    #         repro command that is output on test failure (i.e. whether it is possibly
+    #         relevant to reproducing the test failure). Default: True
+    #     enabled_fn (Callable): Callable returning whether the flag should be enabled
+    #         given the environment variable value and the default value. Default: Lambda
+    #         requiring "0" to disable if on by default OR "1" to enable if off by default.
+    #     implied_by_fn (Callable): Thunk returning a bool to imply this flag as enabled
+    #         by something outside of its primary environment variable setting. For example,
+    #         this can be useful if the value of another environment variable implies the flag
+    #         as enabled. Default: Lambda returning False to indicate no implications.
+    #
+    # Returns the resolved flag value. Asserts that no global called `name` exists yet.
+    @staticmethod
+    def def_flag(
+        name,
+        env_var=None,
+        default=False,
+        include_in_repro=True,
+        enabled_fn=lambda env_var_val, default: (
+            (env_var_val != "0") if default else (env_var_val == "1")),
+        implied_by_fn=lambda: False,
+    ):
+        enabled = default
+        env_var_val = None
+        if env_var is not None:
+            env_var_val = os.getenv(env_var)
+            enabled = enabled_fn(env_var_val, default)
+        implied = implied_by_fn()
+        enabled = enabled or implied
+        # Only env-var-driven, non-default, non-implied values are worth repeating in a repro.
+        if include_in_repro and (env_var is not None) and (enabled != default) and not implied:
+            TestEnvironment.repro_env_vars[env_var] = env_var_val
+
+        # export flag globally for convenience
+        assert name not in globals(), f"duplicate definition of flag '{name}'"
+        globals()[name] = enabled
+        return enabled
+
+    # Defines a setting usable throughout the test suite, determining its value by querying
+    # the specified environment variable. This differs from a flag in that it's not restricted
+    # to a boolean value.
+    #
+    # Args:
+    #     name (str): The name of the setting. A global variable with this name will be set
+    #         for convenient access throughout the test suite.
+    #     env_var (str): The name of the primary environment variable from which to
+    #         determine the value of this setting. If this is None, the default value is
+    #         passed to parse_fn. Otherwise parse_fn receives the environment variable's
+    #         string value, or None when the variable is unset. Default: None
+    #     default (Any): The default value for the setting; also the value the parsed result
+    #         is compared against when deciding whether to record it for the repro command.
+    #         Default: None
+    #     include_in_repro (bool): Indicates whether this setting should be included in the
+    #         repro command that is output on test failure (i.e. whether it is possibly
+    #         relevant to reproducing the test failure). Default: True
+    #     parse_fn (Callable): Callable parsing the env var string. Default value just uses
+    #         the string itself.
+    #
+    # Returns the parsed setting value. Asserts that no global called `name` exists yet.
+    @staticmethod
+    def def_setting(
+        name,
+        env_var=None,
+        default=None,
+        include_in_repro=True,
+        parse_fn=lambda maybe_val_str: maybe_val_str,
+    ):
+        value = default if env_var is None else os.getenv(env_var)
+        value = parse_fn(value)
+        # Mirror def_flag: a setting without an env var cannot be reproduced through the
+        # environment, and recording it would store a None key that is rendered as
+        # "None=<value>" by repro_env_var_prefix().
+        if include_in_repro and (env_var is not None) and (value != default):
+            TestEnvironment.repro_env_vars[env_var] = value
+
+        # export setting globally for convenience
+        assert name not in globals(), f"duplicate definition of setting '{name}'"
+        globals()[name] = value
+        return value
+
+    # Returns a string prefix usable to set environment variables for any test
+    # settings that should be explicitly set to match this instantiation of the
+    # test suite.
+    # Example: "PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_ROCM=1"
+    @staticmethod
+    def repro_env_var_prefix() -> str:
+        return " ".join([f"{env_var}={value}"
+                         for env_var, value in TestEnvironment.repro_env_vars.items()])
+
+
+# Logger named after this module.
+log = logging.getLogger(__name__)
+# NOTE(review): executed as an import-time side effect; see torch.backends for what
+# disabling the global flags entails.
+torch.backends.disable_global_flags()
+
+# Prefix for file URLs; on win32 the three-slash form is used instead.
+FILE_SCHEMA = "file://"
+if sys.platform == 'win32':
+    FILE_SCHEMA = "file:///"
+
+# NB: This flag differs semantically from others in that setting the env var to any
+# non-empty value will cause it to be true:
+#   CI=1, CI="true", CI=0, etc. all set the flag to be true.
+#   CI= and an unset CI set the flag to be false.
+# GitHub sets the value to CI="true" to enable it.
+IS_CI: bool = TestEnvironment.def_flag(
+    "IS_CI",
+    env_var="CI",
+    include_in_repro=False,
+    enabled_fn=lambda env_var_value, _: bool(env_var_value),
+)
+# Off by default: enabled by SANDCASTLE=1, or implied when TW_JOB_USER is "sandcastle".
+IS_SANDCASTLE: bool = TestEnvironment.def_flag(
+    "IS_SANDCASTLE",
+    env_var="SANDCASTLE",
+    implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle",
+    include_in_repro=False,
+)
+# True whenever INSIDE_RE_WORKER is present in the environment, whatever its value
+# (including the empty string).
+IN_RE_WORKER: bool = os.environ.get("INSIDE_RE_WORKER") is not None
+
+# Truthy only when torch._utils_internal defines IS_FBSOURCE and it is truthy.
+_is_fbcode_default = (
+    hasattr(torch._utils_internal, "IS_FBSOURCE") and
+    torch._utils_internal.IS_FBSOURCE
+)
+
+# Defaults to _is_fbcode_default; PYTORCH_TEST_FBCODE overrides it via the default
+# enabled_fn ("0" disables a default-on flag, "1" enables a default-off flag).
+IS_FBCODE: bool = TestEnvironment.def_flag(
+    "IS_FBCODE",
+    env_var="PYTORCH_TEST_FBCODE",
+    default=_is_fbcode_default,
+    include_in_repro=False,
+)
+# Off unless PYTORCH_TEST_REMOTE_GPU=1.
+IS_REMOTE_GPU: bool = TestEnvironment.def_flag(
+    "IS_REMOTE_GPU",
+    env_var="PYTORCH_TEST_REMOTE_GPU",
+    include_in_repro=False,
+)
+
+# Off unless PYTORCH_DISABLE_RUNNING_SCRIPT_CHK=1.
+DISABLE_RUNNING_SCRIPT_CHK: bool = TestEnvironment.def_flag(
+    "DISABLE_RUNNING_SCRIPT_CHK",
+    env_var="PYTORCH_DISABLE_RUNNING_SCRIPT_CHK",
+    include_in_repro=False,
+)
+# NB: enabled by default unless in an fbcode context.
+PRINT_REPRO_ON_FAILURE: bool = TestEnvironment.def_flag(
+    "PRINT_REPRO_ON_FAILURE",
+    env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
+    default=(not IS_FBCODE),
+    include_in_repro=False,
+)
+
+# possibly restrict OpInfo tests to a single sample input
+# None when PYTORCH_OPINFO_SAMPLE_INPUT_INDEX is unset; otherwise the variable parsed
+# with int() (a non-integer value raises ValueError at import time).
+OPINFO_SAMPLE_INPUT_INDEX: Optional[int] = TestEnvironment.def_setting(
+    "OPINFO_SAMPLE_INPUT_INDEX",
+    env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX",
+    default=None,
+    # Don't include the env var value in the repro command because the info will
+    # be queried from the tracked sample input instead
+    include_in_repro=False,
+    parse_fn=lambda val: None if val is None else int(val),
+)
+
+# Default file names for the disabled-tests and slow-tests JSON files.
+DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
+DEFAULT_SLOW_TESTS_FILE = 'slow_tests.json'
+
+# Start out empty; may be rebound to the parsed contents of a JSON file.
+disabled_tests_dict = {}
+slow_tests_dict = {}
+
+def maybe_load_json(filename):
+    """
+    Returns the parsed JSON content of the given file. If the path is not an existing
+    file, a warning is logged and an empty dict is returned instead.
+    """
+    if not os.path.isfile(filename):
+        log.warning("Attempted to load json file '%s' but it does not exist.", filename)
+        return {}
+    with open(filename) as json_file:
+        return json.load(json_file)
+
+# set them here in case the tests are running in a subprocess that doesn't call run_tests
+# An unset or empty env var leaves the corresponding dict untouched.
+if os.getenv("SLOW_TESTS_FILE", ""):
+    slow_tests_dict = maybe_load_json(os.getenv("SLOW_TESTS_FILE", ""))
+if os.getenv("DISABLED_TESTS_FILE", ""):
+    disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))
+
+# Fixed device names plus whatever name the privateuse1 backend currently reports.
+NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', 'mtia', torch._C._get_privateuse1_backend_name())
+
+# used for managing devices testing for torch profiler UTs
+# for now cpu, cuda and xpu are added for testing torch profiler UTs
+DEVICE_LIST_SUPPORT_PROFILING_TEST = ('cpu', 'cuda', 'xpu')
+ALLOW_XPU_PROFILING_TEST = True
+
+# Substrings searched for in platform.platform(); any match marks the host as Jetson.
+check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra', 'thor']
+IS_JETSON = any(name in platform.platform() for name in check_names)
+
+def gcIfJetson(fn):
+    """
+    Decorator that, when IS_JETSON is set, runs a garbage collection pass and empties
+    the CUDA cache before every call to the wrapped function.
+
+    Args:
+        fn (Callable): Function to wrap.
+
+    Returns:
+        Wrapper with the metadata of ``fn`` that forwards all arguments and returns
+        whatever ``fn`` returns.
+    """
+    # Irregular Jetson host/device memory setup requires cleanup to avoid tests being killed
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if IS_JETSON:
+            gc.collect()
+            torch.cuda.empty_cache()
+        # Propagate the result instead of silently discarding it.
+        return fn(*args, **kwargs)
+    return wrapper
+
+# Tries to extract the current test function by crawling the stack.
+# If unsuccessful, return None.
+def extract_test_fn() -> Optional[Callable]:
+    """
+    Walks the call stack looking for a frame whose local ``self`` is a
+    unittest.TestCase whose id() ends in "<its class name>.<a name starting with 'test'>".
+    Returns the underlying function of that test method for the first such frame.
+
+    Returns None if no frame qualifies or if anything raises along the way.
+    """
+    try:
+        for frame_info in inspect.stack():
+            frame_locals = frame_info.frame.f_locals
+            if "self" not in frame_locals:
+                continue
+            case = frame_locals["self"]
+            if not isinstance(case, unittest.TestCase):
+                continue
+            # Only the last two dotted components of the test id matter here.
+            cls_name, test_name = case.id().rsplit('.', 2)[-2:]
+            if cls_name == type(case).__name__ and test_name.startswith("test"):
+                return getattr(case, test_name).__func__
+    except Exception:
+        # Best effort only: any failure is reported as "not found".
+        return None
+    return None
+
+# Contains tracked input data useful for debugging purposes
+@dataclass
+class TrackedInput:
+    # Position of the value within the iterator it came from (see TrackedInputIter).
+    index: int
+    # The tracked value itself.
+    val: Any
+    # Description of what kind of input is tracked (e.g. "sample input", "error input").
+    type_desc: str
+
+# Attempt to pull out tracked input information from the test function.
+# A TrackedInputIter is used to insert this information.
+def get_tracked_input() -> Optional[TrackedInput]:
+    """
+    Returns the ``tracked_input`` attribute of the current test function, or None when
+    no test function can be found or the attribute is not present.
+    """
+    current_test = extract_test_fn()
+    return None if current_test is None else getattr(current_test, "tracked_input", None)
+
+def clear_tracked_input() -> None:
+    """
+    Resets ``tracked_input`` to None on the current test function. Does nothing when no
+    test function can be found or when it has no ``tracked_input`` attribute.
+    """
+    current_test = extract_test_fn()
+    if current_test is not None and hasattr(current_test, "tracked_input"):
+        current_test.tracked_input = None  # type: ignore[attr-defined]
+
+# Wraps an iterator and tracks the most recent value the iterator produces
+# for debugging purposes. Tracked values are stored on the test function.
+class TrackedInputIter:
+    """
+    Iterator wrapper that remembers the most recently produced value by storing a
+    TrackedInput on the current test function (only if that function already has a
+    ``tracked_input`` attribute).
+
+    Args:
+        child_iter (iterable): The iterable to wrap.
+        input_type_desc (str): Describes the things being tracked
+            (e.g. "sample input", "error input").
+        item_callback (Callable): Run on each (iterated thing, index) to get the thing
+            to return from the iterator. Default: returns the iterated thing unchanged.
+        track_callback (Callable): Run on each iterated thing to get the thing to track.
+            Default: tracks the iterated thing itself. The two callbacks are separate
+            because the tracked thing is not always the returned thing, e.g. an
+            ErrorInput is returned while the SampleInput it contains is tracked.
+        set_seed (bool): Whether the random seed is set before each item is pulled
+            from the wrapped iterable. Default: True
+        restrict_to_index (int): If not None, only the item at this index is produced.
+    """
+    def __init__(
+        self,
+        child_iter,
+        input_type_desc,
+        item_callback=None,
+        track_callback=None,
+        set_seed=True,
+        restrict_to_index=None
+    ):
+        self.child_iter = enumerate(child_iter)
+        self.input_type_desc = input_type_desc
+        self.item_callback = (lambda x, i: x) if item_callback is None else item_callback
+        self.track_callback = (lambda x: x) if track_callback is None else track_callback
+        self.test_fn = extract_test_fn()
+        self.set_seed = set_seed
+        self.restrict_to_index = restrict_to_index
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            if self.set_seed:
+                # use a test-name-specific hash for the seed if possible
+                if self.test_fn is None:
+                    seed = SEED
+                else:
+                    name_digest = hashlib.sha256(self.test_fn.__qualname__.encode("utf-8")).digest()
+                    seed = int.from_bytes(name_digest[:4], 'little')
+                set_rng_seed(seed)
+
+            # allow StopIteration to bubble up
+            input_idx, input_val = next(self.child_iter)
+            unrestricted = self.restrict_to_index is None
+            if unrestricted or input_idx == self.restrict_to_index:
+                break
+
+        tracked = TrackedInput(
+            index=input_idx, val=self.track_callback(input_val), type_desc=self.input_type_desc
+        )
+        self._set_tracked_input(tracked)
+        return self.item_callback(input_val, input_idx)
+
+    def _set_tracked_input(self, tracked_input: TrackedInput):
+        # Only overwrite an existing attribute; never add one to a function lacking it.
+        if self.test_fn is not None and hasattr(self.test_fn, "tracked_input"):
+            self.test_fn.tracked_input = tracked_input  # type: ignore[attr-defined]
+
+class _TestParametrizer:
+    """
+    Base decorator class for test parametrization. Decorating a generic test marks it so that
+    it is later expanded into a set of new tests, each specialized for one set of test inputs
+    (e.g. parametrizing over the set of ops yields one test function per op).
+
+    What to parametrize over, and how, is decided by each derived class.
+
+    Mechanically, the decorator attaches a 'parametrize_fn' property to the test function. That
+    function is meant to be invoked later by one of:
+      * Device-specific test instantiation via instantiate_device_type_tests(). In this case
+        there is no need to explicitly parametrize over device type; that is handled separately.
+      * Device-agnostic parametrized test instantiation via instantiate_parametrized_tests().
+
+    Decorators compose: if the test function already carries a 'parametrize_fn', it is replaced
+    by a composite 'parametrize_fn' generating the product of the parameters produced by the
+    existing and the new parametrize_fns.
+    """
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """
+        Parametrizes the given test function along the dimension(s) chosen by the derived class,
+        which may be any dimension or combination of dimensions (all ops, all modules, all ops +
+        their associated dtypes, ...).
+
+        Args:
+            test (fn): Test function to parametrize over
+            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+            device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
+                if the tests are not part of a device-specific set
+
+        Returns:
+            Generator object returning 4-tuples of:
+                test (fn): Parametrized test function; must support a device arg and args for any params
+                test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
+                    the base name of the test
+                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
+                decorator_fn (callable): Callable[[Dict], List] for list of decorators to apply given param_kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, fn):
+        parametrize_fn = self._parametrize_test
+        if hasattr(fn, 'parametrize_fn'):
+            # Already parametrized: compose with the product of args.
+            parametrize_fn = compose_parametrize_fns(fn.parametrize_fn, parametrize_fn)
+        fn.parametrize_fn = parametrize_fn
+        return fn
+
+
+def compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn):
+    """
+    Builds a parametrize_fn covering the product of the parameters handled by the two given
+    parametrize_fns. Both are expected to have the signature f(test, generic_cls, device_cls).
+
+    Generated test names place the name produced by new_parametrize_fn first, followed by the
+    name produced by old_parametrize_fn, joined by "_" when both are non-empty. This matches
+    intuition when stacking parametrization decorators: names are built in top to bottom order.
+
+    Args:
+        old_parametrize_fn (callable) - First parametrize_fn to compose.
+        new_parametrize_fn (callable) - Second parametrize_fn to compose.
+
+    Raises (from the returned function):
+        RuntimeError: if both parametrize_fns produce a parameter with the same name.
+    """
+
+    def composite_fn(test, generic_cls, device_cls,
+                     old_parametrize_fn=old_parametrize_fn,
+                     new_parametrize_fn=new_parametrize_fn):
+        # The old parametrization is fully materialized before the new one is applied to it.
+        outer = list(old_parametrize_fn(test, generic_cls, device_cls))
+        for old_test, old_test_name, old_param_kwargs, old_dec_fn in outer:
+            inner = new_parametrize_fn(old_test, generic_cls, device_cls)
+            for new_test, new_test_name, new_param_kwargs, new_dec_fn in inner:
+                redundant_params = set(old_param_kwargs.keys()).intersection(new_param_kwargs.keys())
+                if redundant_params:
+                    raise RuntimeError('Parametrization over the same parameter by multiple parametrization '
+                                       f'decorators is not supported. For test "{test.__name__}", the following parameters '
+                                       f'are handled multiple times: {redundant_params}')
+
+                separator = '_' if old_test_name != '' and new_test_name != '' else ''
+                merged_test_name = f'{new_test_name}{separator}{old_test_name}'
+                full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
+
+                # Defaults pin this iteration's decorator fns (avoids late binding).
+                def merged_decorator_fn(param_kwargs, old_dec_fn=old_dec_fn, new_dec_fn=new_dec_fn):
+                    return [*old_dec_fn(param_kwargs), *new_dec_fn(param_kwargs)]
+
+                yield (new_test, merged_test_name, full_param_kwargs, merged_decorator_fn)
+
+    return composite_fn
+
+
+def instantiate_parametrized_tests(generic_cls):
+    """
+    Expands every test on the class that carries a parametrize_fn (normally attached by a
+    decorator subclass of _TestParametrizer). Each such generic test is removed from the test
+    class and replaced by parametrized tests with specialized names. Use this instead of
+    instantiate_device_type_tests() if the test class contains device-agnostic tests.
+
+    It also works as a class decorator. E.g.
+
+    ```
+    @instantiate_parametrized_tests
+    class TestFoo(TestCase):
+        ...
+    ```
+
+    Args:
+        generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+
+    Returns:
+        The same class object, modified in place.
+    """
+    def install_test(name, test, param_kwargs):
+        # Binds one parametrization to a new method on the class.
+        @wraps(test)
+        def instantiated_test(self, param_kwargs=param_kwargs):
+            test(self, **param_kwargs)
+
+        assert not hasattr(generic_cls, name), f"Redefinition of test {name}"
+        setattr(generic_cls, name, instantiated_test)
+
+    # Attribute names are snapshotted up front because the loop mutates the class.
+    for attr_name in tuple(dir(generic_cls)):
+        class_attr = getattr(generic_cls, attr_name)
+        if not hasattr(class_attr, 'parametrize_fn'):
+            continue
+
+        # The generic test itself must not remain on the test class.
+        delattr(generic_cls, attr_name)
+
+        parametrizations = class_attr.parametrize_fn(
+            class_attr, generic_cls=generic_cls, device_cls=None)
+        for test, test_suffix, param_kwargs, decorator_fn in parametrizations:
+            full_name = f'{test.__name__}_{test_suffix}'
+
+            # Decorators are chosen based on the full param kwargs.
+            for decorator in decorator_fn(param_kwargs):
+                test = decorator(test)
+
+            install_test(full_name, test, param_kwargs)
+    return generic_cls
+
+
+class subtest:
+    """
+    An explicit subtest case to be used as a value with test parametrization. It makes it
+    possible to name an individual subtest case explicitly and to attach decorators to the
+    parametrized test generated for it.
+
+    Args:
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name (str): Optional name to use for the test.
+        decorators (iterable): Iterable of decorators to apply to the generated test.
+            Any falsy value (including None) is stored as a new empty list.
+    """
+    __slots__ = ['arg_values', 'name', 'decorators']
+
+    def __init__(self, arg_values, name=None, decorators=None):
+        self.decorators = decorators or []
+        self.name = name
+        self.arg_values = arg_values
+
+
+class parametrize(_TestParametrizer):
+    """
+    Decorator applying a generic parametrization to a test.
+
+    Its interface follows `@pytest.mark.parametrize`, and basic usage is identical between
+    the two: the first argument is a string of comma-separated parameter names for the test,
+    the second an iterable producing values (or tuples of values when there are multiple
+    parameters).
+
+    On top of that, the decorator offers functionality that pytest does not:
+
+    1. Parametrized tests become generated test functions on unittest test classes. Because
+    this is unlike pytest, the decorator is also responsible for naming those functions. By
+    default a name is the test's base name followed by each parameter name + value
+    (e.g. "test_bar_x_1_y_foo"); custom names are possible through `name_fn` or the `subtest`
+    structure (see below).
+
+    2. Parameter values of type `subtest` are handled specially, giving finer-grained control
+    over test naming and test execution. In particular, a subtest can carry an explicit test
+    name or arbitrary decorators (see examples below).
+
+    Examples::
+
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        def test_bar(self, x, y):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
+                     name_fn=lambda x, y: '{}_{}'.format(x, y))
+        def test_bar_custom_names(self, x, y):
+            ...
+
+        @parametrize("x, y", [subtest((1, 2), name='double'),
+                              subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
+                              subtest((1, 4), name='quadruple')])
+        def test_baz(self, x, y):
+            ...
+
+    The parametrized tests only come into existence once instantiate_parametrized_tests() or
+    instantiate_device_type_tests() is called. Use the former for test classes containing
+    device-agnostic tests and the latter for test classes containing device-specific tests.
+    Both support arbitrary parametrizations using the decorator.
+
+    Args:
+        arg_str (str): String of arg names separate by commas (e.g. "x,y").
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name_fn (Callable): Optional function that takes in parameters and returns subtest name.
+    """
+    def __init__(self, arg_str, arg_values, name_fn=None):
+        # Empty pieces are dropped before stripping surrounding whitespace.
+        self.arg_names: list[str] = [piece.strip() for piece in arg_str.split(',') if piece]
+        self.arg_values = arg_values
+        self.name_fn = name_fn
+
+    def _formatted_str_repr(self, idx, name, value):
+        """ Returns a string representation for the given arg that is suitable for use in test function names. """
+        if isinstance(value, torch.dtype):
+            return dtype_name(value)
+        if isinstance(value, torch.device):
+            return str(value)
+        # Can't use isinstance as it would cause a circular import
+        if type(value).__name__ in {'OpInfo', 'ModuleInfo'}:
+            return value.formatted_name
+        if isinstance(value, (int, float, str)):
+            return f"{name}_{str(value).replace('.', '_')}"
+        # Anything else is identified by its arg name and the parametrization index.
+        return f"{name}{idx}"
+
+    def _default_subtest_name(self, idx, values):
+        """ Joins the formatted representation of every (arg name, value) pair with "_". """
+        name_parts = [
+            self._formatted_str_repr(idx, arg_name, arg_value)
+            for arg_name, arg_value in zip(self.arg_names, values, strict=True)
+        ]
+        return '_'.join(name_parts)
+
+    def _get_subtest_name(self, idx, values, explicit_name=None):
+        """ Picks the subtest name: explicit name first, then name_fn, then the default scheme. """
+        if explicit_name:
+            return explicit_name
+        if self.name_fn:
+            return self.name_fn(*values)
+        return self._default_subtest_name(idx, values)
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if len(self.arg_names) == 0:
+            # No additional parameters needed for the test.
+            yield (test, '', {}, lambda _: [])
+            return
+
+        # Each "values" item is expected to be either:
+        # * A tuple of values with one for each arg. For a single arg, a single item is expected.
+        # * A subtest instance with arg_values matching the previous.
+        saw_any_values = False
+        for idx, values in enumerate(self.arg_values):
+            saw_any_values = True
+            maybe_name = None
+            decorators: list[Any] = []
+            gen_test = test
+
+            if isinstance(values, subtest):
+                sub = values
+                values, maybe_name, decorators = sub.arg_values, sub.name, sub.decorators
+
+                # A fresh wrapper per subtest, so it is the wrapper that gets handed out
+                # for decoration rather than the shared test function.
+                @wraps(test)
+                def test_wrapper(*args, **kwargs):
+                    return test(*args, **kwargs)
+
+                gen_test = test_wrapper
+
+            values = list(values) if len(self.arg_names) > 1 else [values]  # type: ignore[call-overload]
+            if len(values) != len(self.arg_names):
+                raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
+                                   f'values and {len(self.arg_names)} names for test "{test.__name__}"')
+
+            param_kwargs = dict(zip(self.arg_names, values, strict=True))
+            test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name)
+
+            # The default pins this iteration's decorators (avoids late binding).
+            def decorator_fn(_, decorators=decorators):
+                return decorators
+
+            yield (gen_test, test_name, param_kwargs, decorator_fn)
+
+        if not saw_any_values:
+            raise ValueError(f'{test}: An empty arg_values was passed to @parametrize. '
+                             'Note that this may result from reuse of a generator.')
+
+
+class reparametrize(_TestParametrizer):
+    """
+    Decorator that adjusts how an existing parametrizer operates. The given adapter_fn is run
+    on every parametrization produced by the given parametrizer, which allows on-the-fly
+    parametrization that is more flexible than the default product-based composition obtained
+    by stacking parametrization decorators.
+
+    A parametrization is excluded when the adapter_fn returns None for it, or when an item the
+    adapter_fn produces is None. Otherwise the adapter_fn is expected to return an iterable of
+    modified parametrizations, i.e. (test name, parameter kwargs) pairs.
+
+    Examples::
+
+        def include_is_even_arg(test_name, param_kwargs):
+            x = param_kwargs["x"]
+            is_even = x % 2 == 0
+            new_param_kwargs = dict(param_kwargs)
+            new_param_kwargs["is_even"] = is_even
+            is_even_suffix = "_even" if is_even else "_odd"
+            new_test_name = f"{test_name}{is_even_suffix}"
+            yield (new_test_name, new_param_kwargs)
+
+        ...
+
+        @reparametrize(parametrize("x", range(5)), include_is_even_arg)
+        def test_foo(self, x, is_even):
+            ...
+
+        def exclude_odds(test_name, param_kwargs):
+            x = param_kwargs["x"]
+            is_even = x % 2 == 0
+            yield None if not is_even else (test_name, param_kwargs)
+
+        ...
+
+        @reparametrize(parametrize("x", range(5)), exclude_odds)
+        def test_bar(self, x):
+            ...
+
+    """
+    def __init__(self, parametrizer, adapter_fn):
+        self.adapter_fn = adapter_fn
+        self.parametrizer = parametrizer
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        base_parametrizations = self.parametrizer._parametrize_test(test, generic_cls, device_cls)
+        for gen_test, test_name, param_kwargs, decorator_fn in base_parametrizations:
+            adapted = self.adapter_fn(test_name, param_kwargs)
+            if adapted is None:
+                # The adapter dropped this parametrization entirely.
+                continue
+            for adapted_item in adapted:
+                if adapted_item is None:
+                    continue
+                new_test_name, new_param_kwargs = adapted_item
+                # The test function and decorator_fn pass through untouched.
+                yield (gen_test, new_test_name, new_param_kwargs, decorator_fn)
+
+
+class decorateIf(_TestParametrizer):
+    """
+    Decorator for applying parameter-specific conditional decoration.
+    Composes with other test parametrizers (e.g. @modules, @ops, @parametrize, etc.).
+
+    Examples::
+
+        @decorateIf(unittest.skip, lambda params: params["x"] == 2)
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        @decorateIf(
+            unittest.expectedFailure,
+            lambda params: params["x"] == 3 and params["y"] == "baz"
+        )
+        def test_bar(self, x, y):
+            ...
+
+        @decorateIf(
+            unittest.expectedFailure,
+            lambda params: params["op"].name == "add" and params["dtype"] == torch.float16
+        )
+        @ops(op_db)
+        def test_op_foo(self, device, dtype, op):
+            ...
+
+        @decorateIf(
+            unittest.skip,
+            lambda params: params["module_info"].module_cls is torch.nn.Linear and \
+                params["device"] == "cpu"
+        )
+        @modules(module_db)
+        def test_module_foo(self, device, dtype, module_info):
+            ...
+
+    Args:
+        decorator: Test decorator to apply if the predicate is satisfied.
+        predicate_fn (Callable): Function taking in a dict of params and returning a boolean
+            indicating whether the decorator should be applied or not.
+    """
+    def __init__(self, decorator, predicate_fn):
+        self.decorator = decorator
+        self.predicate_fn = predicate_fn
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+
+        # Leave test as-is and return the appropriate decorator_fn.
+        def decorator_fn(params, decorator=self.decorator, predicate_fn=self.predicate_fn):
+            if predicate_fn(params):
+                return [decorator]
+            else:
+                return []
+
+        @wraps(test)
+        def test_wrapper(*args, **kwargs):
+            return test(*args, **kwargs)
+
+        test_name = ''
+        yield (test_wrapper, test_name, {}, decorator_fn)
+
+
+def cppProfilingFlagsToProfilingMode():
+    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    torch._C._jit_set_profiling_executor(old_prof_exec_state)
+    torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+    if old_prof_exec_state:
+        if old_prof_mode_state:
+            return ProfilingMode.PROFILING
+        else:
+            return ProfilingMode.SIMPLE
+    else:
+        return ProfilingMode.LEGACY
+
+@contextmanager
+def enable_profiling_mode_for_profiling_tests():
+    old_prof_exec_state = False
+    old_prof_mode_state = False
+    assert GRAPH_EXECUTOR
+    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+        old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+        old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    try:
+        yield
+    finally:
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            torch._C._jit_set_profiling_executor(old_prof_exec_state)
+            torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+@contextmanager
+def enable_profiling_mode():
+    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_profiling_executor(old_prof_exec_state)
+        torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+@contextmanager
+def num_profiled_runs(num_runs):
+    old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_num_profiled_runs(old_num_runs)
+
+# Keep references to the original TorchScript `__call__` implementations; the
+# `prof_*_call` wrappers defined below dispatch to these after `__call__` is replaced.
+func_call = torch._C.ScriptFunction.__call__
+meth_call = torch._C.ScriptMethod.__call__
+
+def prof_callable(callable, *args, **kwargs):
+    if 'profile_and_replay' in kwargs:
+        del kwargs['profile_and_replay']
+        assert GRAPH_EXECUTOR
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            with enable_profiling_mode_for_profiling_tests():
+                callable(*args, **kwargs)
+                return callable(*args, **kwargs)
+
+    return callable(*args, **kwargs)
+
+def raise_on_run_directly(file_to_call):
+    raise RuntimeError("This test file is not meant to be run directly, "
+                       f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
+                       "instead.")
+
+def prof_func_call(*args, **kwargs):
+    """Route a ``ScriptFunction`` call through ``prof_callable`` using the saved ``func_call``."""
+    return prof_callable(func_call, *args, **kwargs)
+
+def prof_meth_call(*args, **kwargs):
+    """Route a ``ScriptMethod`` call through ``prof_callable`` using the saved ``meth_call``."""
+    return prof_callable(meth_call, *args, **kwargs)
+
+# Module-level side effect: from here on, calling any ScriptFunction / ScriptMethod goes
+# through the wrappers above.
+torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[method-assign]
+torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[method-assign]
+
+def _get_test_report_path():
+    # allow users to override the test file location. We need this
+    # because the distributed tests run the same test file multiple
+    # times with different configurations.
+    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
+    test_source = override if override is not None else 'python-unittest'
+    return os.path.join('test-reports', test_source)
+
+def parse_cmd_line_args():
+    """
+    Parse the test-harness options from ``sys.argv`` and publish them as module globals.
+
+    Options this parser recognizes are stored in the globals declared below; everything
+    it does not recognize is kept (prefixed with ``sys.argv[0]``) in ``UNITTEST_ARGS``
+    for the test runner. ``expecttest.ACCEPT`` is set from ``--accept`` unless it is
+    already truthy. Finally ``set_rng_seed()`` is called.
+    """
+    global DISABLED_TESTS_FILE
+    global GRAPH_EXECUTOR
+    global LOG_SUFFIX
+    global PYTEST_SINGLE_TEST
+    global REPEAT_COUNT
+    global RERUN_DISABLED_TESTS
+    global RUN_PARALLEL
+    global SHOWLOCALS
+    global SLOW_TESTS_FILE
+    global TEST_BAILOUTS
+    global TEST_DISCOVER
+    global TEST_IN_SUBPROCESS
+    global TEST_SAVE_XML
+    global UNITTEST_ARGS
+    global USE_PYTEST
+
+    # argparse's own -h/--help is disabled when the entry script is run_test.py.
+    is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
+    parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
+    parser.add_argument('--subprocess', action='store_true',
+                        help='whether to run each test in a subprocess')
+    parser.add_argument('--accept', action='store_true')
+    parser.add_argument('--jit-executor', '--jit_executor', type=str)
+    parser.add_argument('--repeat', type=int, default=1)
+    parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
+    parser.add_argument('--use-pytest', action='store_true')
+    # Bare `--save-xml` uses the default report path; in CI the default path is used
+    # even when the flag is absent.
+    parser.add_argument('--save-xml', nargs='?', type=str,
+                        const=_get_test_report_path(),
+                        default=_get_test_report_path() if IS_CI else None)
+    parser.add_argument('--discover-tests', action='store_true')
+    parser.add_argument('--log-suffix', type=str, default="")
+    parser.add_argument('--run-parallel', type=int, default=1)
+    parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
+    parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
+    parser.add_argument('--rerun-disabled-tests', action='store_true')
+    parser.add_argument('--pytest-single-test', type=str, nargs=1)
+    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
+
+    # Only run when -h or --help flag is active to display both unittest and parser help messages.
+    def run_unittest_help(argv):
+        unittest.main(argv=argv)
+
+    if '-h' in sys.argv or '--help' in sys.argv:
+        # unittest's help is printed from a separate thread which is joined before parsing;
+        # presumably so that unittest.main's exit does not end the main thread -- TODO confirm.
+        help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
+        help_thread.start()
+        help_thread.join()
+
+    args, remaining = parser.parse_known_args()
+    if args.jit_executor == 'legacy':
+        GRAPH_EXECUTOR = ProfilingMode.LEGACY
+    elif args.jit_executor == 'profiling':
+        GRAPH_EXECUTOR = ProfilingMode.PROFILING
+    elif args.jit_executor == 'simple':
+        GRAPH_EXECUTOR = ProfilingMode.SIMPLE
+    else:
+        # infer flags based on the default settings
+        GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
+
+    RERUN_DISABLED_TESTS = args.rerun_disabled_tests
+
+    SLOW_TESTS_FILE = args.import_slow_tests
+    DISABLED_TESTS_FILE = args.import_disabled_tests
+    LOG_SUFFIX = args.log_suffix
+    RUN_PARALLEL = args.run_parallel
+    TEST_BAILOUTS = args.test_bailouts
+    USE_PYTEST = args.use_pytest
+    # Note: because of `nargs=1` this is a one-element list (or None), not a string.
+    PYTEST_SINGLE_TEST = args.pytest_single_test
+    TEST_DISCOVER = args.discover_tests
+    TEST_IN_SUBPROCESS = args.subprocess
+    TEST_SAVE_XML = args.save_xml
+    REPEAT_COUNT = args.repeat
+    SHOWLOCALS = args.showlocals
+    # Never downgrade an ACCEPT that is already enabled.
+    if not getattr(expecttest, "ACCEPT", False):
+        expecttest.ACCEPT = args.accept
+    UNITTEST_ARGS = [sys.argv[0]] + remaining
+
+    set_rng_seed()
+
+
+def wait_for_process(p, timeout=None):
+    try:
+        return p.wait(timeout=timeout)
+    except KeyboardInterrupt:
+        # Give `p` a chance to handle KeyboardInterrupt. Without this,
+        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
+        exit_status = p.wait(timeout=5)
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+            raise
+    except subprocess.TimeoutExpired:
+        # send SIGINT to give pytest a chance to make xml
+        p.send_signal(signal.SIGINT)
+        exit_status = None
+        try:
+            exit_status = p.wait(timeout=5)
+        # try to handle the case where p.wait(timeout=5) times out as well as
+        # otherwise the wait() call in the finally block can potentially hang
+        except subprocess.TimeoutExpired:
+            pass
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+        raise
+    except:  # noqa: B001,E722, copied from python core library
+        p.kill()
+        raise
+    finally:
+        # Always call p.wait() to ensure exit
+        p.wait()
+
+def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    # The following cool snippet is copied from Py3 core library subprocess.call
+    # only the with
+    #   1. `except KeyboardInterrupt` block added for SIGINT handling.
+    #   2. In Py2, subprocess.Popen doesn't return a context manager, so we do
+    #      `p.wait()` in a `final` block for the code to be portable.
+    #
+    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+    assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
+    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
+    return wait_for_process(p, timeout=timeout)
+
+
+def retry_shell(
+    command,
+    cwd=None,
+    env=None,
+    stdout=None,
+    stderr=None,
+    timeout=None,
+    retries=1,
+    was_rerun=False,
+) -> tuple[int, bool]:
+    # Returns exicode + whether it was rerun
+    assert (
+        retries >= 0
+    ), f"Expecting non negative number for number of retries, got {retries}"
+    try:
+        exit_code = shell(
+            command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout
+        )
+        if exit_code == 0 or retries == 0:
+            return exit_code, was_rerun
+        print(
+            f"Got exit code {exit_code}, retrying (retries left={retries})",
+            file=stdout,
+            flush=True,
+        )
+    except subprocess.TimeoutExpired:
+        if retries == 0:
+            print(
+                f"Command took >{timeout // 60}min, returning 124",
+                file=stdout,
+                flush=True,
+            )
+            return 124, was_rerun
+        print(
+            f"Command took >{timeout // 60}min, retrying (retries left={retries})",
+            file=stdout,
+            flush=True,
+        )
+    return retry_shell(
+        command,
+        cwd=cwd,
+        env=env,
+        stdout=stdout,
+        stderr=stderr,
+        timeout=timeout,
+        retries=retries - 1,
+        was_rerun=True,
+    )
+
+
+def discover_test_cases_recursively(suite_or_case):
+    if isinstance(suite_or_case, unittest.TestCase):
+        return [suite_or_case]
+    rc = []
+    for element in suite_or_case:
+        print(element)
+        rc.extend(discover_test_cases_recursively(element))
+    return rc
+
+def get_test_names(test_cases):
+    return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
+
+def _print_test_names():
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    test_cases = discover_test_cases_recursively(suite)
+    for name in get_test_names(test_cases):
+        print(name)
+
+def chunk_list(lst, nchunks):
+    return [lst[i::nchunks] for i in range(nchunks)]
+
+# sanitize filename e.g., distributed/pipeline/sync/skip/test_api.py -> distributed.pipeline.sync.skip.test_api
+def sanitize_test_filename(filename):
+    strip_py = re.sub(r'.py$', '', filename)
+    return re.sub('/', r'.', strip_py)
+
+def lint_test_case_extension(suite):
+    succeed = True
+    for test_case_or_suite in suite:
+        test_case = test_case_or_suite
+        if isinstance(test_case_or_suite, unittest.TestSuite):
+            first_test = test_case_or_suite._tests[0] if len(test_case_or_suite._tests) > 0 else None
+            if first_test is not None and isinstance(first_test, unittest.TestSuite):
+                return succeed and lint_test_case_extension(test_case_or_suite)
+            test_case = first_test
+
+        if test_case is not None:
+            if not isinstance(test_case, TestCase):
+                test_class = test_case.id().split('.', 1)[1].split('.')[0]
+                err = "This test class should extend from torch.testing._internal.common_utils.TestCase but it doesn't."
+                print(f"{test_class} - failed. {err}")
+                succeed = False
+    return succeed
+
+
+def get_report_path(argv=None, pytest=False):
+    if argv is None:
+        argv = UNITTEST_ARGS
+    test_filename = sanitize_test_filename(argv[0])
+    test_report_path = TEST_SAVE_XML + LOG_SUFFIX
+    test_report_path = os.path.join(test_report_path, test_filename)
+    if pytest:
+        test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
+        os.makedirs(test_report_path, exist_ok=True)
+        test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
+        return test_report_path
+    os.makedirs(test_report_path, exist_ok=True)
+    return test_report_path
+
+
+def sanitize_pytest_xml(xml_file: str):
+    # pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml
+    # consider somehow modifying the XML logger in conftest to do this instead
+    import xml.etree.ElementTree as ET
+    tree = ET.parse(xml_file)
+    for testcase in tree.iter('testcase'):
+        full_classname = testcase.attrib.get("classname")
+        if full_classname is None:
+            continue
+        # The test prefix is optional
+        regex_result = re.search(r"^(test\.)?(?P.*)\.(?P[^\.]*)$", full_classname)
+        if regex_result is None:
+            continue
+        classname = regex_result.group("classname")
+        file = regex_result.group("file").replace(".", "/")
+        testcase.set("classname", classname)
+        testcase.set("file", f"{file}.py")
+    tree.write(xml_file)
+
+
+def get_pytest_test_cases(argv: list[str]) -> list[str]:
+    class TestCollectorPlugin:
+        def __init__(self) -> None:
+            self.tests: list[Any] = []
+
+        def pytest_collection_finish(self, session):
+            for item in session.items:
+                self.tests.append(session.config.cwd_relative_nodeid(item.nodeid))
+
+    test_collector_plugin = TestCollectorPlugin()
+    import pytest
+    pytest.main(
+        [arg for arg in argv if arg != '-vv'] + ['--collect-only', '-qq', '--use-main-module'],
+        plugins=[test_collector_plugin]
+    )
+    return test_collector_plugin.tests
+
+
+def run_tests(argv=None):
+    """
+    Entry point for running the tests of the ``__main__`` module.
+
+    Parses the command line (``parse_cmd_line_args``), optionally loads the slow /
+    disabled test lists, and then picks exactly one launch mechanism, checked in this
+    order: test discovery only, one subprocess per test, parallel shards, pytest,
+    unittest with XML reports, repeated unittest runs, or plain ``unittest.main``.
+
+    Args:
+        argv: Arguments for the underlying runner; defaults to ``UNITTEST_ARGS``.
+    """
+    parse_cmd_line_args()
+    if argv is None:
+        argv = UNITTEST_ARGS
+
+    # import test files.
+    if SLOW_TESTS_FILE:
+        if os.path.exists(SLOW_TESTS_FILE):
+            with open(SLOW_TESTS_FILE) as fp:
+                global slow_tests_dict
+                slow_tests_dict = json.load(fp)
+                # use env vars so pytest-xdist subprocesses can still access them
+                os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE
+        else:
+            warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}', stacklevel=2)
+    if DISABLED_TESTS_FILE:
+        if os.path.exists(DISABLED_TESTS_FILE):
+            with open(DISABLED_TESTS_FILE) as fp:
+                global disabled_tests_dict
+                disabled_tests_dict = json.load(fp)
+                os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE
+        else:
+            warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}', stacklevel=2)
+    # Determine the test launch mechanism
+    if TEST_DISCOVER:
+        _print_test_names()
+        return
+
+    # Before running the tests, lint to check that every test class extends from TestCase
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    if not lint_test_case_extension(suite):
+        sys.exit(1)
+
+    if SHOWLOCALS:
+        # Insert the runner-specific "show locals" flags right after the program name.
+        argv = [
+            argv[0],
+            *(["--showlocals", "--tb=long", "--color=yes"] if USE_PYTEST else ["--locals"]),
+            *argv[1:],
+        ]
+
+    if TEST_IN_SUBPROCESS:
+        # Re-invoke this script once per test case, forwarding the relevant harness flags.
+        other_args = []
+        if DISABLED_TESTS_FILE:
+            other_args.append("--import-disabled-tests")
+        if SLOW_TESTS_FILE:
+            other_args.append("--import-slow-tests")
+        if USE_PYTEST:
+            other_args.append("--use-pytest")
+        if RERUN_DISABLED_TESTS:
+            other_args.append("--rerun-disabled-tests")
+        if TEST_SAVE_XML:
+            other_args += ['--save-xml', TEST_SAVE_XML]
+
+        test_cases = (
+            get_pytest_test_cases(argv) if USE_PYTEST else
+            [case.id().split('.', 1)[1] for case in discover_test_cases_recursively(suite)]
+        )
+
+        failed_tests = []
+
+        for test_case_full_name in test_cases:
+
+            cmd = (
+                [sys.executable] + [argv[0]] + other_args + argv[1:] +
+                (["--pytest-single-test"] if USE_PYTEST else []) +
+                [test_case_full_name]
+            )
+            string_cmd = " ".join(cmd)
+
+            # 15 minute limit per test, with one retry; no limit and no retry when
+            # rerunning disabled tests.
+            timeout = None if RERUN_DISABLED_TESTS else 15 * 60
+
+            exitcode, _ = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
+
+            if exitcode != 0:
+                # This is sort of hacky, but add on relevant env variables for distributed tests.
+                if 'TestDistBackendWithSpawn' in test_case_full_name:
+                    backend = os.environ.get("BACKEND", "")
+                    world_size = os.environ.get("WORLD_SIZE", "")
+                    env_prefix = f"BACKEND={backend} WORLD_SIZE={world_size}"
+                    string_cmd = env_prefix + " " + string_cmd
+                # Log the command to reproduce the failure.
+                print(f"Test exited with non-zero exitcode {exitcode}. Command to reproduce: {string_cmd}")
+                failed_tests.append(test_case_full_name)
+
+            # NOTE(review): this assert sits inside the loop, so the run stops at the first
+            # failing test and `failed_tests` never holds more than one entry -- confirm
+            # whether fail-fast is intended given the plural message.
+            assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
+                len(failed_tests), '\n\t'.join(failed_tests))
+
+    elif RUN_PARALLEL > 1:
+        # Split the test names into RUN_PARALLEL interleaved batches, one process each.
+        test_cases = discover_test_cases_recursively(suite)
+        test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
+        processes = []
+        for i in range(RUN_PARALLEL):
+            command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i]
+            processes.append(subprocess.Popen(command, universal_newlines=True))
+        failed = False
+        for p in processes:
+            failed |= wait_for_process(p) != 0
+        assert not failed, "Some test shards have failed"
+    elif USE_PYTEST:
+        pytest_args = argv + ["--use-main-module"]
+        test_report_path = ""
+        if TEST_SAVE_XML:
+            test_report_path = get_report_path(pytest=True)
+            print(f'Test results will be stored in {test_report_path}')
+            pytest_args.append(f'--junit-xml-reruns={test_report_path}')
+        if PYTEST_SINGLE_TEST:
+            # PYTEST_SINGLE_TEST is a one-element list; it replaces the program name slot.
+            pytest_args = PYTEST_SINGLE_TEST + pytest_args[1:]
+
+        import pytest
+        os.environ["NO_COLOR"] = "1"
+        exit_code = pytest.main(args=pytest_args)
+        if TEST_SAVE_XML:
+            sanitize_pytest_xml(test_report_path)
+
+        # exitcode of 5 means no tests were found, which happens since some test configs don't
+        # run tests from certain files
+        sys.exit(0 if exit_code == 5 else exit_code)
+    elif TEST_SAVE_XML:
+        # import here so that non-CI doesn't need xmlrunner installed
+        import xmlrunner  # type: ignore[import]
+        from xmlrunner.result import _XMLTestResult  # type: ignore[import]
+
+        class XMLTestResultVerbose(_XMLTestResult):
+            """
+            Adding verbosity to test outputs:
+            by default test summary prints 'skip',
+            but we want to also print the skip reason.
+            GH issue: https://github.com/pytorch/pytorch/issues/69014
+
+            This works with unittest_xml_reporting<=3.2.0,>=2.0.0
+            (3.2.0 is latest at the moment)
+            """
+
+            def addSkip(self, test, reason):
+                super().addSkip(test, reason)
+                for c in self.callback.__closure__:
+                    if isinstance(c.cell_contents, str) and c.cell_contents == 'skip':
+                        # this message is printed in test summary;
+                        # it stands for `verbose_str` captured in the closure
+                        c.cell_contents = f"skip: {reason}"
+
+            def printErrors(self) -> None:
+                super().printErrors()
+                self.printErrorList("XPASS", self.unexpectedSuccesses)
+        test_report_path = get_report_path()
+        verbose = '--verbose' in argv or '-v' in argv
+        if verbose:
+            print(f'Test results will be stored in {test_report_path}')
+        unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
+            output=test_report_path,
+            verbosity=2 if verbose else 1,
+            resultclass=XMLTestResultVerbose))
+    elif REPEAT_COUNT > 1:
+        # Stop at the first unsuccessful repetition.
+        for _ in range(REPEAT_COUNT):
+            if not unittest.main(exit=False, argv=argv).result.wasSuccessful():
+                sys.exit(-1)
+    else:
+        unittest.main(argv=argv)
+
+# Operating system and CPU architecture of the running interpreter, computed once at import.
+IS_LINUX = sys.platform == "linux"
+IS_WINDOWS = sys.platform == "win32"
+IS_MACOS = sys.platform == "darwin"
+IS_PPC = platform.machine() == "ppc64le"
+IS_X86 = platform.machine() in ('x86_64', 'i386')
+IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
+IS_S390X = platform.machine() == "s390x"
+
+def is_avx512_vnni_supported():
+    if sys.platform != 'linux':
+        return False
+    with open("/proc/cpuinfo", encoding="ascii") as f:
+        lines = f.read()
+    return "vnni" in lines
+
+IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()
+
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryFileName(*args, **kwargs):
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        if 'delete' in kwargs:
+            if kwargs['delete'] is not False:
+                raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.")
+        else:
+            kwargs['delete'] = False
+        f = tempfile.NamedTemporaryFile(*args, **kwargs)  # noqa:SIM115
+        try:
+            f.close()
+            yield f.name
+        finally:
+            os.unlink(f.name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryFileName(*args, **kwargs):
+        with tempfile.NamedTemporaryFile(*args, **kwargs) as f:
+            yield f.name
+
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryDirectoryName(suffix=None):
+        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
+        # so we first create the directory using mkdtemp and then remove it manually
+        try:
+            dir_name = tempfile.mkdtemp(suffix=suffix)
+            yield dir_name
+        finally:
+            shutil.rmtree(dir_name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryDirectoryName(suffix=None):
+        with tempfile.TemporaryDirectory(suffix=suffix) as d:
+            yield d
+
+
+def is_privateuse1_backend_available():
+    privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
+    privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None)
+    return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()
+
+
+def make_lazy_class(cls):
+
+    def lazy_init(self, cb):
+        self._cb = cb
+        self._value = None
+
+    cls.__init__ = lazy_init
+
+    for basename in [
+        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
+        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
+        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
+    ]:
+        name = f"__{basename}__"
+
+        def inner_wrapper(name):
+            use_operator = basename not in ("bool", "int")
+
+            def wrapped(self, *args, **kwargs):
+                if self._cb is not None:
+                    self._value = self._cb()
+                    self._cb = None
+                if not use_operator:
+                    return getattr(self._value, name)(*args, **kwargs)
+                else:
+                    return getattr(operator, name)(self._value, *args, **kwargs)
+            return wrapped
+
+        setattr(cls, name, inner_wrapper(name))
+
+    return cls
+
+
+@make_lazy_class
+class LazyVal:
+    """Lazily evaluated value; all behavior is installed by the ``make_lazy_class`` decorator."""
+    pass
+
+
+IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
+
+# Availability of optional modules and backends, evaluated once at import time.
+TEST_NUMPY = _check_module_exists('numpy')
+TEST_FAIRSEQ = _check_module_exists('fairseq')
+TEST_SCIPY = _check_module_exists('scipy')
+TEST_MKL = torch.backends.mkl.is_available()
+TEST_ACL = torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_acl_supported()
+TEST_MPS = torch.backends.mps.is_available()
+# "major.minor" of the macOS version as a float; -1.0 when platform.mac_ver() reports no version.
+MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
+TEST_XPU = torch.xpu.is_available()
+TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available())
+TEST_CUDA = torch.cuda.is_available()
+# Wrapped in LazyVal so the accelerator queries only run when the value is first used.
+TEST_ACCELERATOR = LazyVal(lambda: torch.accelerator.is_available())  # type: ignore[call-arg]
+TEST_MULTIACCELERATOR = LazyVal(lambda: torch.accelerator.device_count() > 1)  # type: ignore[call-arg]
+# Module registered under the PrivateUse1 backend name, or None if torch has no such attribute.
+custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
+TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
+TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
+TEST_NUMBA = _check_module_exists('numba')
+TEST_TRANSFORMERS = _check_module_exists('transformers')
+TEST_DILL = _check_module_exists('dill')
+
+# librosa is treated as unavailable on ARM64 even when it is installed.
+TEST_LIBROSA = _check_module_exists('librosa') and not IS_ARM64
+
+TEST_OPT_EINSUM = _check_module_exists('opt_einsum')
+
+TEST_Z3 = _check_module_exists('z3')
+
+def split_if_not_empty(x: str):
+    return x.split(",") if len(x) != 0 else []
+
+# True when "cpu" is listed in the comma-separated PYTORCH_TESTING_DEVICE_EXCEPT_FOR env var.
+NOTEST_CPU = "cpu" in split_if_not_empty(os.getenv('PYTORCH_TESTING_DEVICE_EXCEPT_FOR', ''))
+
+skipIfNoDill = unittest.skipIf(not TEST_DILL, "no dill")
+
+
+NO_MULTIPROCESSING_SPAWN: bool = False
+# Sanitizer / platform flags, each backed by the named environment variable.
+TEST_WITH_ASAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_ASAN",
+    env_var="PYTORCH_TEST_WITH_ASAN",
+)
+TEST_WITH_DEV_DBG_ASAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_DEV_DBG_ASAN",
+    env_var="PYTORCH_TEST_WITH_DEV_DBG_ASAN",
+)
+TEST_WITH_TSAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TSAN",
+    env_var="PYTORCH_TEST_WITH_TSAN",
+)
+TEST_WITH_UBSAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_UBSAN",
+    env_var="PYTORCH_TEST_WITH_UBSAN",
+)
+TEST_WITH_ROCM: bool = TestEnvironment.def_flag(
+    "TEST_WITH_ROCM",
+    env_var="PYTORCH_TEST_WITH_ROCM",
+)
+TEST_WITH_MTIA: bool = TestEnvironment.def_flag(
+    "TEST_WITH_MTIA",
+    env_var="PYTORCH_TEST_WITH_MTIA",
+)
+
+# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+# See #64427
+TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1'
+# Enables tests that are slow to run (disabled by default)
+TEST_WITH_SLOW: bool = TestEnvironment.def_flag(
+    "TEST_WITH_SLOW",
+    env_var="PYTORCH_TEST_WITH_SLOW",
+)
+
+# Disables non-slow tests (these tests are enabled by default).
+# This is usually used in conjunction with TEST_WITH_SLOW to
+# run *only* slow tests.  (I could have done an enum, but
+# it felt a little awkward.)
+TEST_SKIP_FAST: bool = TestEnvironment.def_flag(
+    "TEST_SKIP_FAST",
+    env_var="PYTORCH_TEST_SKIP_FAST",
+)
+
+# Enables crossref tests, in addition to standard tests which
+# are being run.  crossref tests work by installing a torch
+# function mode that runs extra compute alongside the regular
+# computation that happens with the test.  After both computations
+# are done, we cross-reference them (thus the name) to check for
+# correctness, before throwing out the extra compute and proceeding
+# as we had before.  By default, we don't run these tests.
+TEST_WITH_CROSSREF: bool = TestEnvironment.def_flag(
+    "TEST_WITH_CROSSREF",
+    env_var="PYTORCH_TEST_WITH_CROSSREF",
+)
+
+TEST_SKIP_CUDAGRAPH: bool = TestEnvironment.def_flag(
+    "TEST_SKIP_CUDAGRAPH",
+    env_var="PYTORCH_TEST_SKIP_CUDAGRAPH",
+)
+# CUDA graph tests need CUDA, must not be skipped via the flag above, and on HIP builds
+# require a HIP version of at least 5.3.
+TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
+    torch.version.cuda or
+    (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
+)
+
+# Both of the following require a CUDA build with major version >= 12.
+TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
+
+TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
+    torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12
+)
+
+if TEST_CUDA_PYTHON_BINDINGS:
+    def cuda_python_error_check(function_call_output):
+        """Unpack the result tuple of a cuda-python runtime call.
+
+        The first element is the status; unless it equals cudaSuccess a
+        ValueError is raised, otherwise the remaining elements are returned.
+        """
+        import cuda.bindings  # type: ignore[import]
+
+        status, *payload = function_call_output
+        if status != cuda.bindings.runtime.cudaError_t.cudaSuccess:
+            raise ValueError(f"CUDA failure! {status}")
+        return tuple(payload)
+else:
+    cuda_python_error_check = None  # type: ignore[assignment]
+
+def allocator_option_enabled_fn(allocator_config, _, option):
+    """Return True iff ``option`` maps to 'True' in a ``key:value,...`` config string.
+
+    Entries are split on the first ':' only, so an empty string or a value that
+    itself contains ':' no longer raises ValueError; the last duplicate key wins.
+    """
+    if allocator_config is None:
+        return False
+    mapping = dict(entry.partition(':')[::2] for entry in allocator_config.split(','))
+    return mapping.get(option) == 'True'
+
+EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag(
+    "EXPANDABLE_SEGMENTS",
+    env_var="PYTORCH_CUDA_ALLOC_CONF",
+    enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'),
+)
+
+if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
+    num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2"))
+    gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30
+    # other libraries take up a little under 1 GB (~0.85 GB) of space per process
+    torch.cuda.set_per_process_memory_fraction(round((gb_available - num_procs * .85) / gb_available / num_procs, 2))
+
+requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "Requires CUDA")
+
+def skipIfCrossRef(fn):
+    """Skip ``fn`` at call time when TEST_WITH_CROSSREF is set; otherwise run it."""
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_WITH_CROSSREF:
+            raise unittest.SkipTest("test doesn't currently work with crossref")
+        return fn(*args, **kwargs)
+    return wrapper
+
+class CrossRefMode(torch.overrides.TorchFunctionMode):
+    # Pass-through mode: every intercepted call is forwarded unchanged.
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        call_kwargs = {} if not kwargs else kwargs
+        return func(*args, **call_kwargs)
+
+# Run PyTorch tests with TorchDynamo
+TEST_WITH_TORCHINDUCTOR: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TORCHINDUCTOR",
+    env_var="PYTORCH_TEST_WITH_INDUCTOR",
+)
+# AOT_EAGER not tested in ci, useful for debugging
+TEST_WITH_AOT_EAGER: bool = TestEnvironment.def_flag(
+    "TEST_WITH_AOT_EAGER",
+    env_var="PYTORCH_TEST_WITH_AOT_EAGER",
+)
+TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TORCHDYNAMO",
+    env_var="PYTORCH_TEST_WITH_DYNAMO",
+    implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER,
+)
+TEST_WITHOUT_COMPILED_AUTOGRAD: bool = TestEnvironment.def_flag(
+    "TEST_WITHOUT_COMPILED_AUTOGRAD",
+    env_var="PYTORCH_TEST_WITHOUT_COMPILED_AUTOGRAD",
+)
+
+if TEST_WITH_TORCHDYNAMO:
+    import torch._dynamo
+    # Do not spend time on helper functions that are called with different inputs
+    torch._dynamo.config.accumulated_recompile_limit = 64
+    # Do not log compilation metrics from unit tests
+    torch._dynamo.config.log_compilation_metrics = False
+    # Silence 3.13.0 guard performance warnings
+    torch._dynamo.config.issue_3_13_0_warning = False
+    if TEST_WITH_TORCHINDUCTOR:
+        import torch._inductor.config
+        torch._inductor.config.fallback_random = True
+    else:
+        # only dynamo for now
+        torch._dynamo.config.compiled_autograd = not TEST_WITHOUT_COMPILED_AUTOGRAD
+
+
+# seems like this is only used in test/torch_np
+def xpassIfTorchDynamo_np(func):
+    """Expected failure without dynamo; under dynamo pass ``func`` through (skip on numpy 2.x)."""
+    if not TEST_WITH_TORCHDYNAMO:
+        return unittest.expectedFailure(func)
+    return unittest.skip("skipping numpy 2.0+ dynamo-wrapped test")(func) if np.__version__[0] == '2' else func
+
+
+def xfailIfACL(func):  # expected failure only when TEST_ACL is set
+    return func if not TEST_ACL else unittest.expectedFailure(func)
+
+
+def xfailIfTorchDynamo(func):  # expected failure only when TEST_WITH_TORCHDYNAMO is set
+    return func if not TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)
+
+
+def xfailIfPy312Plus(func):  # expected failure only on Python 3.12 and newer
+    return func if sys.version_info < (3, 12) else unittest.expectedFailure(func)
+
+
+def xfailIfLinux(func):  # expected failure on Linux, except under ROCm or fbcode
+    return func if not IS_LINUX or TEST_WITH_ROCM or IS_FBCODE else unittest.expectedFailure(func)
+
+
+def xfailIfWindows(func):  # expected failure only when IS_WINDOWS is set
+    return func if not IS_WINDOWS else unittest.expectedFailure(func)
+
+
+def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
+    """
+    Skip a test function (at call time) or a whole test class (by marking it
+    at decoration time) when TEST_WITH_TORCHDYNAMO is set.
+
+    Usage:
+    @skipIfTorchDynamo(msg)
+    def test_blah(self):
+        ...
+    """
+    assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?"
+
+    def decorator(fn):
+        if isinstance(fn, type):
+            if TEST_WITH_TORCHDYNAMO:
+                fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+                fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+            return fn
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if TEST_WITH_TORCHDYNAMO:
+                raise unittest.SkipTest(msg)
+            fn(*call_args, **call_kwargs)
+        return guarded
+
+    return decorator
+
+def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
+                        condition=TEST_WITH_TORCHINDUCTOR):
+    """Skip a test function (at call time) or mark a test class skipped when ``condition``
+    holds; the default is the value TEST_WITH_TORCHINDUCTOR had at definition time.
+    """
+    def decorator(fn):
+        if isinstance(fn, type):
+            if condition:
+                fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+                fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+            return fn
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if condition:
+                raise unittest.SkipTest(msg)
+            fn(*call_args, **call_kwargs)
+        return guarded
+
+    return decorator
+
+def runWithoutCompiledAutograd(msg="test doesn't currently work with compiled autograd"):
+    """
+    Run the decorated test inside torch._dynamo.compiled_autograd._disable().
+    ``msg`` is only type-checked here. Usage:
+    @runWithoutCompiledAutograd(msg)
+    def test_blah(self): ...
+    """
+    assert isinstance(msg, str)
+
+    def decorator(func):
+        @wraps(func)
+        def without_compiled_autograd(*call_args, **call_kwargs):
+            with torch._dynamo.compiled_autograd._disable():
+                func(*call_args, **call_kwargs)
+        return without_compiled_autograd
+
+    return decorator
+
+def serialTest(condition=True):
+    """
+    Decorator factory that applies pytest.mark.serial when pytest is present and ``condition`` is True.
+    """
+    # If the decorator is applied directly (without calling it), ``condition`` would be
+    # the test function itself and the test would effectively be skipped, which is undesirable.
+    assert type(condition) is bool
+
+    def decorator(fn):
+        if not (has_pytest and condition):
+            return fn
+        return pytest.mark.serial(fn)
+    return decorator
+
+def unMarkDynamoStrictTest(cls=None):
+    """Set ``dynamo_strict = False`` on a class; works bare or as ``@unMarkDynamoStrictTest()``."""
+
+    def decorator(cls):
+        cls.dynamo_strict = False
+        return cls
+
+    # No argument means we were called as a factory; otherwise decorate directly.
+    return decorator if cls is None else decorator(cls)
+
+
+def markDynamoStrictTest(cls_or_func=None, nopython=False):
+    """
+    Marks the test as 'strict'. A class only gets the ``dynamo_strict`` and
+    ``dynamo_strict_nopython`` attributes set; a function is wrapped so that dynamo
+    is reset before and after it and it runs with suppress_errors patched to False.
+
+    Args:
+    - nopython: if we should run torch._dynamo.optimize with nopython={True/False}.
+    """
+    def decorator(cls_or_func):
+        if not inspect.isclass(cls_or_func):
+            @wraps(cls_or_func)
+            def strict_wrapper(*call_args, **call_kwargs):
+                torch._dynamo.reset()
+                with unittest.mock.patch("torch._dynamo.config.suppress_errors", False):
+                    cls_or_func(*call_args, **call_kwargs)
+                # Not in a try/finally: this trailing reset is skipped if the test raises.
+                torch._dynamo.reset()
+            return strict_wrapper
+
+        cls_or_func.dynamo_strict = True
+        cls_or_func.dynamo_strict_nopython = nopython
+        return cls_or_func
+
+    # Called without a target -> act as a decorator factory; otherwise decorate directly.
+    if cls_or_func is None:
+        return decorator
+    return decorator(cls_or_func)
+
+
+def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
+    return skipIfTorchInductor(condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR, msg=msg)
+
+def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
+    """Skip a test function (at call time) or mark a test class skipped when
+    GRAPH_EXECUTOR is ProfilingMode.LEGACY.
+    """
+    def decorator(fn):
+        if isinstance(fn, type):
+            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+                fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+                fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+            return fn
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            # GRAPH_EXECUTOR must be truthy by the time the test runs.
+            assert GRAPH_EXECUTOR
+            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+                raise unittest.SkipTest(msg)
+            fn(*call_args, **call_kwargs)
+        return guarded
+
+    return decorator
+
+
+def make_dynamo_test(
+    fn: Optional[Callable[..., Any]] = None
+) -> Callable[..., Any]:
+    """
+    Decorator function to create a dynamo test case. A function annotated with
+    this decorator takes as input a unittest object (plus optional keyword arguments).
+    """
+    from torch._dynamo.testing import CompileCounter, reset, optimize_assert
+    if fn is None:
+        return lambda fn: make_dynamo_test(fn)  # supports use as @make_dynamo_test()
+
+    def standard_test(
+        self: Any,
+        fn: Callable[..., Any],
+        kwargs,
+    ) -> None:
+        def dummy() -> None:
+            fn(self, **kwargs)
+
+        actual = CompileCounter()
+
+        dummy()  # first run: call the test body directly
+        reset()
+        opt_fn = optimize_assert(actual)(dummy)  # second run: same body wrapped by optimize_assert
+        opt_fn()
+        reset()
+
+    @functools.wraps(fn)
+    def test_fn(self: Any, **kwargs) -> None:
+        return standard_test(
+            self,
+            fn=fn,
+            kwargs=kwargs,
+        )
+
+    return test_fn
+
+
+# Run PyTorch tests with translation validation on.
+TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'
+
+if TEST_WITH_TV:
+    torch.fx.experimental._config.translation_validation = True
+
+# Determine whether to enable cuda memory leak check.
+# CUDA mem leak check is expensive and thus we don't want to execute it on every
+# test case / configuration.
+# If this is True then CUDA memory leak checks are skipped. If this is false
+#   then CUDA memory leak checks are performed.
+# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135
+TEST_CUDA_MEM_LEAK_CHECK: bool = TestEnvironment.def_flag(
+    "TEST_CUDA_MEM_LEAK_CHECK",
+    env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK",
+)
+
+
+# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
+numpy_to_torch_dtype_dict = {
+    np.bool_      : torch.bool,
+    np.uint8      : torch.uint8,
+    np.uint16     : torch.uint16,
+    np.uint32     : torch.uint32,
+    np.uint64     : torch.uint64,
+    np.int8       : torch.int8,
+    np.int16      : torch.int16,
+    np.int32      : torch.int32,
+    np.int64      : torch.int64,
+    np.float16    : torch.float16,
+    np.float32    : torch.float32,
+    np.float64    : torch.float64,
+    np.complex64  : torch.complex64,
+    np.complex128 : torch.complex128
+}
+
+
+# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
+# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
+# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
+def numpy_to_torch_dtype(np_dtype):
+    # Look the dtype up as given; on a miss retry with its ``.type``.
+    if np_dtype in numpy_to_torch_dtype_dict:
+        return numpy_to_torch_dtype_dict[np_dtype]
+    return numpy_to_torch_dtype_dict[np_dtype.type]
+
+
+def has_corresponding_torch_dtype(np_dtype):  # True unless the lookup raises KeyError
+    try:
+        numpy_to_torch_dtype(np_dtype)
+    except KeyError:
+        return False
+    return True
+
+
+if IS_WINDOWS:
+    # Size of `np.intc` is platform defined.
+    # It is returned by functions like `bitwise_not`.
+    # On Windows `int` is 32-bit
+    # https://docs.microsoft.com/en-us/cpp/cpp/data-type-ranges?view=msvc-160
+    numpy_to_torch_dtype_dict[np.intc] = torch.int
+
+# Dict of torch dtype -> NumPy dtype
+torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
+torch_to_numpy_dtype_dict.update({
+    torch.bfloat16: np.float32,
+    torch.complex32: np.complex64
+})
+
+def skipIfNNModuleInlined(
+    msg="test doesn't currently work with nn module inlining",
+    condition=torch._dynamo.config.inline_inbuilt_nn_modules,
+):
+    """Skip a test function (at call time) or mark a test class skipped when ``condition``
+    holds; the default is read from torch._dynamo.config.inline_inbuilt_nn_modules once,
+    when this function is defined.
+    """
+    def decorator(fn):
+        if isinstance(fn, type):
+            if condition:
+                fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+                fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+            return fn
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if condition:
+                raise unittest.SkipTest(msg)
+            fn(*call_args, **call_kwargs)
+
+        return guarded
+
+    return decorator
+
+def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
+    """Skip the test at call time when TEST_WITH_ROCM is set; usable bare or with ``msg=``."""
+    def dec_fn(fn):
+        reason = f"skipIfRocm: {msg}"
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if TEST_WITH_ROCM:
+                raise unittest.SkipTest(reason)
+            return fn(*call_args, **call_kwargs)
+        return guarded
+
+    # A truthy ``func`` means bare use (@skipIfRocm); otherwise hand back the decorator.
+    return dec_fn(func) if func else dec_fn
+
+def getRocmArchName(device_index: int = 0):  # reads gcnArchName from the device's properties
+    return torch.cuda.get_device_properties(device_index).gcnArchName
+
+def isRocmArchAnyOf(arch: tuple[str, ...]):
+    current = getRocmArchName()  # each entry is tested with ``in`` against this name
+    return any(candidate in current for candidate in arch)
+
+def skipIfRocmArch(arch: tuple[str, ...]):
+    """Skip the test at call time on ROCm when the current arch matches any entry of ``arch``."""
+    def dec_fn(fn):
+        @wraps(fn)
+        def guarded(self, *call_args, **call_kwargs):
+            if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
+                raise unittest.SkipTest(f"skipIfRocm: test skipped on {arch}")
+            return fn(self, *call_args, **call_kwargs)
+        return guarded
+    return dec_fn
+
+def runOnRocm(fn):
+    """Run ``fn`` only when TEST_WITH_ROCM is set; otherwise raise SkipTest at call time."""
+    @wraps(fn)
+    def rocm_only(*call_args, **call_kwargs):
+        if not TEST_WITH_ROCM:
+            raise unittest.SkipTest("test currently only works on the ROCm stack")
+        fn(*call_args, **call_kwargs)
+    return rocm_only
+
+def runOnRocmArch(arch: tuple[str, ...]):
+    """On ROCm, skip the test at call time unless the current arch matches an entry of ``arch``."""
+    def dec_fn(fn):
+        @wraps(fn)
+        def guarded(self, *call_args, **call_kwargs):
+            if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
+                raise unittest.SkipTest(f"skipIfRocm: test only runs on {arch}")
+            return fn(self, *call_args, **call_kwargs)
+        return guarded
+    return dec_fn
+
+def xfailIfS390X(func):  # expected failure only when IS_S390X is set
+    return func if not IS_S390X else unittest.expectedFailure(func)
+
+def xfailIf(condition):
+    """Return a decorator that marks its target as an expected failure
+    when ``condition`` is truthy and leaves it untouched otherwise."""
+    def wrapper(func):
+        # ``condition`` is consulted when the decorator is applied.
+        return unittest.expectedFailure(func) if condition else func
+    return wrapper
+
+def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
+    """Skip the test at call time when TEST_XPU is set; usable bare or with ``msg=``."""
+    def dec_fn(fn):
+        reason = f"skipIfXpu: {msg}"
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if TEST_XPU:
+                raise unittest.SkipTest(reason)
+            return fn(*call_args, **call_kwargs)
+        return guarded
+
+    # A truthy ``func`` means bare use (@skipIfXpu); otherwise hand back the decorator.
+    return dec_fn(func) if func else dec_fn
+
+def skipIfMPS(fn):
+    """Raise SkipTest at call time when TEST_MPS is set; otherwise run ``fn``."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if TEST_MPS:
+            raise unittest.SkipTest("test doesn't currently work with MPS")
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+
+def skipIfHpu(fn):
+    """Raise SkipTest at call time when TEST_HPU is set; otherwise run ``fn``."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if TEST_HPU:
+            raise unittest.SkipTest("test doesn't currently work with HPU")
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def getRocmVersion() -> tuple[int, int]:
+    from torch.testing._internal.common_cuda import _get_torch_rocm_version  # imported at call time
+    rocm_version = _get_torch_rocm_version()
+    return (rocm_version[0], rocm_version[1])  # keep only the first two components
+
+# Skips a test on CUDA if ROCm is available and its version is lower than requested.
+def skipIfRocmVersionLessThan(version=None):
+    """On ROCm, skip the test when ``version`` is None or the ROCm version is below ``version``."""
+    def dec_fn(fn):
+        @wraps(fn)
+        def guarded(self, *call_args, **call_kwargs):
+            if TEST_WITH_ROCM:
+                found = getRocmVersion()
+                if found is None or version is None or found < tuple(version):
+                    raise unittest.SkipTest(f"ROCm {found} is available but {version} required")
+            return fn(self, *call_args, **call_kwargs)
+        return guarded
+    return dec_fn
+
+def skipIfNotMiopenSuggestNHWC(fn):
+    """Raise SkipTest at call time unless TEST_WITH_MIOPEN_SUGGEST_NHWC is set."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not TEST_WITH_MIOPEN_SUGGEST_NHWC:
+            raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation")
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def skipIfWindows(func=None, *, msg="test doesn't currently work on the Windows stack"):
+    """Skip the test at call time when IS_WINDOWS is set; usable bare or with ``msg=``."""
+    def dec_fn(fn):
+        reason = f"skipIfWindows: {msg}"
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if IS_WINDOWS:  # noqa: F821
+                raise unittest.SkipTest(reason)
+            return fn(*call_args, **call_kwargs)
+        return guarded
+
+    # A truthy ``func`` means bare use (@skipIfWindows); otherwise hand back the decorator.
+    return dec_fn(func) if func else dec_fn
+
+def skipIfWindowsXPU(func=None, *, msg="test doesn't currently work on the Windows stack"):
+    """Skip the test at call time on Windows when torch.xpu.is_available() is true."""
+    def dec_fn(fn):
+        reason = f"skipIfWindowsXPU: {msg}"
+
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            if IS_WINDOWS and torch.xpu.is_available():  # noqa: F821
+                raise unittest.SkipTest(reason)
+            return fn(*call_args, **call_kwargs)
+        return guarded
+
+    # A truthy ``func`` means bare use (@skipIfWindowsXPU); otherwise hand back the decorator.
+    return dec_fn(func) if func else dec_fn
+
+def requires_cuda_p2p_access():
+    cuda_p2p_access_available = (  # baseline: CUDA present, capability >= (8, 0), at least two devices
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (8, 0)
+        and torch.cuda.device_count() >= 2
+    )
+    num_devices = torch.cuda.device_count()
+    for i in range(num_devices - 1):  # visit every device pair (i, j) with i < j
+        for j in range(i + 1, num_devices):
+            if not torch.cuda.can_device_access_peer(i, j):
+                cuda_p2p_access_available = False
+                break
+        if not cuda_p2p_access_available:  # stop scanning once the flag is False
+            break
+
+    return skip_but_pass_in_sandcastle_if(
+        not cuda_p2p_access_available,
+        "cuda p2p access is not available",
+    )
+
+# Restores the linalg backend that was preferred before the test ran, so that a
+# failure (or backend change) in one test does not affect other tests
+def setLinalgBackendsToDefaultFinally(fn):
+    @wraps(fn)
+    def restoring(*call_args, **call_kwargs):
+        saved = torch.backends.cuda.preferred_linalg_library()  # snapshot taken before the test
+        try:
+            fn(*call_args, **call_kwargs)
+        finally:
+            torch.backends.cuda.preferred_linalg_library(saved)  # reinstated even if the test raised
+    return restoring
+
+
+# Restores the blas backend that was preferred before the test ran, so that a
+# failure (or backend change) in one test does not affect other tests
+def setBlasBackendsToDefaultFinally(fn):
+    @wraps(fn)
+    def restoring(*call_args, **call_kwargs):
+        saved = torch.backends.cuda.preferred_blas_library()  # snapshot taken before the test
+        try:
+            fn(*call_args, **call_kwargs)
+        finally:
+            torch.backends.cuda.preferred_blas_library(saved)  # reinstated even if the test raised
+    return restoring
+
+
+# Context manager for setting deterministic flag and automatically
+# resetting it to its original value
+class DeterministicGuard:
+    def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
+        self.deterministic = deterministic
+        self.warn_only = warn_only
+        self.fill_uninitialized_memory = fill_uninitialized_memory
+
+    @classmethod
+    def _current_state(cls):
+        # Snapshot the three global settings into a new guard instance.
+        enabled = torch.are_deterministic_algorithms_enabled()
+        warn = torch.is_deterministic_algorithms_warn_only_enabled()
+        fill = torch.utils.deterministic.fill_uninitialized_memory  # type: ignore[attr-defined]
+        return cls(enabled, warn_only=warn, fill_uninitialized_memory=fill)
+
+    def _update(self):
+        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
+        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory  # type: ignore[attr-defined]
+
+    def __enter__(self):
+        self._restore = self._current_state()
+        self._update()
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        self._restore._update()
+
+class AlwaysWarnTypedStorageRemoval:
+    def __init__(self, always_warn):
+        assert isinstance(always_warn, bool)
+        self.always_warn = always_warn  # value applied while the guard is active
+
+    def __enter__(self):
+        self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal()  # remember the current value
+        torch.storage._set_always_warn_typed_storage_removal(self.always_warn)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore)  # put the remembered value back
+
+# Context manager for setting cuda sync debug mode and reset it
+# to original value
+# we are not exposing it to the core because sync debug mode is
+# global and thus not thread safe
+class CudaSyncGuard:
+    def __init__(self, sync_debug_mode):
+        self.mode = sync_debug_mode  # mode applied while the guard is active
+
+    def __enter__(self):
+        self.debug_mode_restore = torch.cuda.get_sync_debug_mode()  # remember the current mode
+        torch.cuda.set_sync_debug_mode(self.mode)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.cuda.set_sync_debug_mode(self.debug_mode_restore)  # put the remembered mode back
+
+# Context manager for setting torch.__future__.set_swap_module_params_on_conversion
+# and automatically resetting it to its original value
+class SwapTensorsGuard:
+    def __init__(self, use_swap_tensors):
+        self.use_swap_tensors = use_swap_tensors  # True/False to set the flag, None to leave it untouched
+
+    def __enter__(self):
+        self.swap_tensors_restore = torch.__future__.get_swap_module_params_on_conversion()  # remember the current flag
+        if self.use_swap_tensors is not None:
+            torch.__future__.set_swap_module_params_on_conversion(self.use_swap_tensors)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.__future__.set_swap_module_params_on_conversion(self.swap_tensors_restore)  # restored even when nothing was set on entry
+
+# This decorator can be used for API tests that call
+# torch.use_deterministic_algorithms().  When the test is finished, it will
+# restore the previous deterministic flag setting.
+#
+# If CUDA >= 10.2, this will set the environment variable
+# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that
+# setting is not thrown during the test unless the test changes that variable
+# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be
+# restored once the test is finished.
+#
+# Note that if a test requires CUDA to actually register the changed
+# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because
+# CUDA only checks the variable when the runtime initializes. Tests can be
+# run inside a subprocess like so:
+#
+#   import subprocess, sys, os
+#   script = '''
+#   # Test code should go here
+#   '''
+#   try:
+#       subprocess.check_output(
+#           [sys.executable, '-c', script],
+#           stderr=subprocess.STDOUT,
+#           cwd=os.path.dirname(os.path.realpath(__file__)),
+#           env=os.environ.copy())
+#   except subprocess.CalledProcessError as e:
+#       error_message = e.output.decode('utf-8')
+#       # Handle exceptions raised by the subprocess here
+#
+def wrapDeterministicFlagAPITest(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        class CuBLASConfigGuard:
+            # Pins CUBLAS_WORKSPACE_CONFIG for the test and restores/removes it afterwards.
+            cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
+
+            def __enter__(self):
+                self.cublas_config_restore = os.environ.get(self.cublas_var_name)
+                os.environ[self.cublas_var_name] = ':4096:8'
+
+            def __exit__(self, exception_type, exception_value, traceback):
+                if self.cublas_config_restore is not None:
+                    os.environ[self.cublas_var_name] = self.cublas_config_restore
+                else:
+                    os.environ.pop(self.cublas_var_name, None)
+
+        flag_guard = DeterministicGuard(
+            torch.are_deterministic_algorithms_enabled(),
+            warn_only=torch.is_deterministic_algorithms_warn_only_enabled())
+        with flag_guard, CuBLASConfigGuard():
+            fn(*args, **kwargs)
+    return wrapper
+
+# This decorator can be used for API tests that want to safely call
+# torch.__future__.set_swap_module_params_on_conversion.  `swap` can be set to
+# True, False or None where None indicates that the context manager does not
+# set the flag. When the test is finished, it will restore the previous swap
+# flag setting.
+def wrapSwapTensorsTest(swap=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def guarded(*call_args, **call_kwargs):
+            with SwapTensorsGuard(swap):  # the previous flag value is restored when the test ends
+                fn(*call_args, **call_kwargs)
+        return guarded
+    return dec_fn
+
+# test parametrizer for swapping
+class swap(_TestParametrizer):
+    def __init__(self, swap_values):
+        super().__init__()
+        self.swap_values = swap_values
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        for value in self.swap_values:  # one wrapped variant of ``test`` per configured value
+            yield wrapSwapTensorsTest(value)(test), f'swap_{value}', {}, lambda _: []
+
+def skipIfCompiledWithoutNumpy(fn):
+    """Skip ``fn`` at call time unless TEST_NUMPY is set AND PyTorch was built with numpy.
+
+    Even if the numpy module is present, a build with `USE_NUMPY=0` makes numpy
+    tests fail, so a tiny torch.from_numpy() probe is run at decoration time.
+    """
+    numpy_support = TEST_NUMPY
+    if numpy_support:
+        try:
+            torch.from_numpy(np.array([2, 2]))
+        except RuntimeError:
+            numpy_support = False
+
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not numpy_support:
+            raise unittest.SkipTest("PyTorch was compiled without numpy support")
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def _test_function(fn, device):
+    def run_test_function(self):  # binds ``device`` so the returned callable only takes ``self``
+        return fn(self, device)
+    return run_test_function
+
+def skipIfNoXNNPACK(fn):
+    """Raise SkipTest at call time unless torch.backends.xnnpack.enabled is true."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not torch.backends.xnnpack.enabled:  # type: ignore[attr-defined]
+            raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def skipIfNoLapack(fn):
+    """Raise SkipTest at call time unless torch._C.has_lapack is true."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not torch._C.has_lapack:
+            raise unittest.SkipTest('PyTorch compiled without Lapack')
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def skipIfNotRegistered(op_name, message):
+    """Return a decorator that unconditionally skips the decorated test.
+
+    Args:
+        op_name: Unused.
+        message: Unused; the skip reason is always
+            "Pytorch is compiled without Caffe2".
+
+    Usage:
+        @skipIfNotRegistered('MyOp', 'MyOp is not linked!')
+    """
+    return unittest.skip("Pytorch is compiled without Caffe2")
+
+def skipIfNoSciPy(fn):
+    """Raise SkipTest at call time unless TEST_SCIPY is set; otherwise run ``fn``."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not TEST_SCIPY:
+            raise unittest.SkipTest("test require SciPy, but SciPy not found")
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def skip_if_pytest(fn):
+    """Skip ``fn`` at call time when the PYTEST_CURRENT_TEST environment variable is set."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if os.environ.get("PYTEST_CURRENT_TEST") is not None:
+            raise unittest.SkipTest("does not work under pytest")
+        return fn(*call_args, **call_kwargs)
+    return guarded
+
+def skipIfNoXPU(fn):
+    """Raise SkipTest at call time unless TEST_XPU is set; otherwise run ``fn``."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not TEST_XPU:
+            raise unittest.SkipTest("test required PyTorched compiled with XPU")
+        fn(*call_args, **call_kwargs)
+    return guarded
+
+def slowTest(fn):
+    """Skip ``fn`` unless TEST_WITH_SLOW is set, and tag the wrapper with ``slow_test = True``."""
+    @wraps(fn)
+    def guarded(*call_args, **call_kwargs):
+        if not TEST_WITH_SLOW:
+            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+        fn(*call_args, **call_kwargs)
+    guarded.__dict__['slow_test'] = True
+    return guarded
+
+
+def slowTestIf(condition):  # slowTest when ``condition`` is truthy, identity decorator otherwise
+    return (lambda fn: fn) if not condition else slowTest
+
+
+def skipCUDAMemoryLeakCheckIf(condition):  # decorator factory: sets fn._do_cuda_memory_leak_check = not condition
+    def dec(fn):
+        if getattr(fn, '_do_cuda_memory_leak_check', True):  # only while the flag is unset or still truthy
+            fn._do_cuda_memory_leak_check = not condition
+        return fn
+    return dec
+
+def skipCUDANonDefaultStreamIf(condition):  # decorator factory: sets fn._do_cuda_non_default_stream = not condition
+    def dec(fn):
+        if getattr(fn, '_do_cuda_non_default_stream', True):  # only while the flag is unset or still truthy
+            fn._do_cuda_non_default_stream = not condition
+        return fn
+    return dec
+
+def suppress_warnings(fn):
+    @wraps(fn)
+    def quiet(*call_args, **call_kwargs):
+        with warnings.catch_warnings():  # filter changes are undone when the block exits
+            warnings.simplefilter("ignore")
+            fn(*call_args, **call_kwargs)
+    return quiet
+
+
+def to_gpu(obj, type_map=None):
+    """Recursively copy tensors (to "cuda"), storages, lists and tuples; deepcopy anything else."""
+    if type_map is None:
+        type_map = {}
+    if isinstance(obj, torch.Tensor):
+        assert obj.is_leaf
+        target_dtype = type_map.get(obj.dtype, obj.dtype)
+        with torch.no_grad():
+            moved = obj.to(dtype=target_dtype, device="cuda", copy=True)
+            moved.requires_grad = obj.requires_grad
+        return moved
+    if torch.is_storage(obj):
+        return obj.new().resize_(obj.size()).copy_(obj)  # type: ignore[attr-defined, union-attr]
+    if isinstance(obj, list):
+        return [to_gpu(item, type_map) for item in obj]
+    if isinstance(obj, tuple):
+        return tuple(to_gpu(item, type_map) for item in obj)
+    return deepcopy(obj)
+
+
+def get_function_arglist(func):  # the ``args`` field of inspect.getfullargspec(func)
+    return inspect.getfullargspec(func).args
+
+
+def set_rng_seed(seed=None):
+    """Seed torch, ``random`` and (if TEST_NUMPY) numpy; ``None`` falls back to the global SEED."""
+    seed = SEED if seed is None else seed
+    torch.manual_seed(seed)
+    random.seed(seed)
+    if TEST_NUMPY:
+        np.random.seed(seed)
+
+
+@contextlib.contextmanager
+def set_default_dtype(dtype):
+    previous = torch.get_default_dtype()  # remembered so it can be reinstated on exit
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(previous)
+
+@contextlib.contextmanager
+def set_default_tensor_type(tensor_type):
+    previous = torch.tensor([]).type()  # type string of a freshly created default tensor
+    torch.set_default_tensor_type(tensor_type)
+    try:
+        yield
+    finally:
+        torch.set_default_tensor_type(previous)
+
+def iter_indices(tensor):
+    """Iterate all indices of ``tensor``: nothing for 0-d, ints for 1-d, index tuples otherwise."""
+    rank = tensor.dim()
+    if rank > 1:
+        return product(*map(range, tensor.size()))
+    return range(tensor.size(0)) if rank == 1 else range(0)
+
+
+def is_iterable(obj):  # True iff ``iter(obj)`` does not raise TypeError
+    try:
+        iter(obj)
+    except TypeError:
+        return False
+    return True
+
+
+def is_iterable_of_tensors(iterable, include_empty=False):
+    """ Returns True if iterable is an iterable of tensors and False o.w.
+
+        A single tensor counts as "not an iterable of tensors", and so does any
+        object for which ``len()`` or iteration raises TypeError.
+
+        If the iterable is empty, the return value is :attr:`include_empty`
+    """
+    # Tensor itself is iterable so we check this first
+    if isinstance(iterable, torch.Tensor):
+        return False
+
+    try:
+        if len(iterable) == 0:
+            return include_empty
+
+        # all() stops at the first non-tensor element.
+        return all(isinstance(item, torch.Tensor) for item in iterable)
+
+    except TypeError:
+        return False
+
+
+class CudaNonDefaultStream:
+    """Context manager that runs its body on freshly created CUDA streams.
+
+    On entry the current stream of every CUDA device is recorded in
+    ``self.beforeStreams`` and replaced by a new ``torch.cuda.Stream``; on exit
+    the recorded streams are made current again.
+    """
+
+    def __enter__(self):
+        # Before starting CUDA test save currently active streams on all
+        # CUDA devices and set new non default streams to all CUDA devices
+        # to ensure CUDA tests do not use default stream by mistake.
+        beforeDevice = torch.cuda.current_device()
+        self.beforeStreams = []
+        for d in range(torch.cuda.device_count()):
+            self.beforeStreams.append(torch.cuda.current_stream(d))
+            deviceStream = torch.cuda.Stream(device=d)
+            # Wait for the work queued on the recorded stream before switching away from it.
+            self.beforeStreams[-1].synchronize()
+            torch._C._cuda_setStream(stream_id=deviceStream.stream_id,
+                                     device_index=deviceStream.device_index,
+                                     device_type=deviceStream.device_type)
+        # Re-select the device that was current before the loop.
+        # NOTE(review): presumably _cuda_setStream can change the current device - confirm.
+        torch._C._cuda_setDevice(beforeDevice)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # After completing CUDA test load previously active streams on all
+        # CUDA devices.
+        beforeDevice = torch.cuda.current_device()
+        for d in range(torch.cuda.device_count()):
+            torch._C._cuda_setStream(stream_id=self.beforeStreams[d].stream_id,
+                                     device_index=self.beforeStreams[d].device_index,
+                                     device_type=self.beforeStreams[d].device_type)
+        # Returns None, so an exception raised by the body keeps propagating.
+        torch._C._cuda_setDevice(beforeDevice)
+
+class CudaMemoryLeakCheck:
+    """Context manager that compares CUDA memory usage before and after its body.
+
+    On exit, a per-device increase in caching-allocator memory that persists
+    after garbage collection produces a warning when the driver statistics do
+    not show an increase as well, and a ``RuntimeError`` when they do.
+    """
+
+    def __init__(self, testcase, name=None):
+        # ``name`` is only used in the warning/error messages; defaults to the test id.
+        self.name = testcase.id() if name is None else name
+        self.testcase = testcase
+
+        # initialize context & RNG to prevent false positive detections
+        # when the test is the first to initialize those
+        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
+        initialize_cuda_context_rng()
+
+    # Stores CUDA memory data provided by PyTorch's caching allocator and
+    #   the CUDA driver.
+    #
+    # NOTE: The undocumented torch.cuda.mem_get_info() returns
+    #   (#free bytes, #total bytes available) on the GPU
+    def __enter__(self):
+        # Per-device baselines, indexed by device ordinal.
+        self.caching_allocator_befores = []
+        self.driver_befores = []
+
+        # Performs a gc if required (required if any CUDA memory is held)
+        num_devices = torch.cuda.device_count()
+        for i in range(num_devices):
+            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+            # NOTE: gc is based exclusively on caching allocator memory
+            #   because the driver will always have some bytes in use (context size?)
+            if caching_allocator_mem_allocated > 0:
+                gc.collect()
+                torch._C._cuda_clearCublasWorkspaces()
+                torch.cuda.empty_cache()
+                # One collection pass is enough; it is not repeated per device.
+                break
+
+        # Acquires caching allocator and driver statistics before the test is run
+        for i in range(num_devices):
+            self.caching_allocator_befores.append(torch.cuda.memory_allocated(i))
+            bytes_free, bytes_total = torch.cuda.mem_get_info(i)
+            driver_mem_allocated = bytes_total - bytes_free
+            self.driver_befores.append(driver_mem_allocated)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Don't check for leaks if an exception was thrown
+        if exc_type is not None:
+            return
+
+        # Compares caching allocator before/after statistics
+        # An increase in allocated memory is a discrepancy indicating a possible
+        #   memory leak
+        discrepancy_detected = False
+        num_devices = torch.cuda.device_count()
+        for i in range(num_devices):
+            # avoid counting cublasWorkspace allocations
+            torch._C._cuda_clearCublasWorkspaces()
+            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+
+            if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
+                discrepancy_detected = True
+                break
+
+        # Short-circuits if no discrepancy detected
+        if not discrepancy_detected:
+            return
+
+        # Validates the discrepancy persists after garbage collection and
+        #   is confirmed by the driver API
+
+        # NOTE: driver API discrepancies alone are ignored because with the jiterator
+        #   some tests may permanently increase the CUDA context size and
+        #   that will appear as a driver memory leak but is the expected behavior.
+
+        # GCs and clears the cache
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        for i in range(num_devices):
+
+            # Assume a leak on this device until one of the queries below clears it.
+            discrepancy_detected = True
+
+            # Query memory multiple times to ensure leak was not transient
+            for _ in range(3):
+                caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+                bytes_free, bytes_total = torch.cuda.mem_get_info(i)
+                driver_mem_allocated = bytes_total - bytes_free
+
+                caching_allocator_discrepancy = False
+                driver_discrepancy = False
+
+                if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
+                    caching_allocator_discrepancy = True
+
+                if driver_mem_allocated > self.driver_befores[i]:
+                    driver_discrepancy = True
+
+                if not (caching_allocator_discrepancy or driver_discrepancy):
+                    # Leak was false positive, exit loop
+                    discrepancy_detected = False
+                    break
+
+            if not discrepancy_detected:
+                continue
+
+            # The flags below hold the values from the last of the queries above.
+            if caching_allocator_discrepancy and not driver_discrepancy:  # type: ignore[possibly-undefined]
+                # Just raises a warning if the leak is not validated by the
+                #   driver API
+                # NOTE: this may be a problem with how the caching allocator collects its
+                #   statistics or a leak too small to trigger the allocation of an
+                #   additional block of memory by the CUDA driver
+                msg = ("CUDA caching allocator reports a memory leak not "  # type: ignore[possibly-undefined]
+                       f"verified by the driver API in {self.name}! "
+                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "  # type: ignore[possibly-undefined]
+                       f"on device {i}. "
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")  # type: ignore[possibly-undefined]
+                warnings.warn(msg, stacklevel=2)
+            elif caching_allocator_discrepancy and driver_discrepancy:  # type: ignore[possibly-undefined]
+                # A caching allocator discrepancy validated by the driver API is a
+                #   failure (except on ROCm, see below)
+                # NOTE(review): no ROCm-specific handling follows in this method, so the
+                #   parenthetical above looks stale - confirm.
+                msg = (f"CUDA driver API confirmed a leak in {self.name}! "  # type: ignore[possibly-undefined]
+                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "  # type: ignore[possibly-undefined]
+                       f"on device {i}. "
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")  # type: ignore[possibly-undefined]
+
+                raise RuntimeError(msg)
+
+@contextmanager
+def skip_exception_type(exc_type):
+    try:
+        yield
+    except exc_type as e:
+        raise unittest.SkipTest(f"not implemented: {e}") from e
+
+@contextmanager
+def print_repro_on_failure(repro_parts):
+    try:
+        yield
+    except unittest.SkipTest:
+        raise
+    except Exception as e:
+        # Get the index of the sample input that failed the test if possible.
+        sample_isolation_prefix = ""
+        tracked_input = getattr(e, "_tracked_input", None)
+        if tracked_input is not None:
+            sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}"
+
+        repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts)))
+
+        open_source_signpost(
+            subsystem="test_repros",
+            name="test_failure",
+            parameters=json.dumps(
+                {
+                    "repro": " ".join(filter(None, (sample_isolation_prefix, *repro_parts))),
+                }
+            ),
+        )
+
+        repro_msg = f"""
+To execute this test, run the following from the base repo dir:
+    {repro_str}
+
+This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
+
+        # NB: Hacking the exception args is the cleanest way I've found to append
+        # failure reproduction info without poisoning the stack trace.
+        if len(e.args) >= 1:
+            e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:])
+        raise
+
+#  "min_satisfying_examples" setting has been deprecated in hypothesis
+#  3.56.0 and removed in hypothesis 4.x
+try:
+    import hypothesis
+
+    def settings(*args, **kwargs):
+        """Forward to ``hypothesis.settings``, dropping ``min_satisfying_examples`` on hypothesis >= 3.56.0."""
+        if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
+            kwargs.pop('min_satisfying_examples')
+        return hypothesis.settings(*args, **kwargs)
+
+
+    # Profile used when IS_CI is set: derandomized, 50 examples.
+    hypothesis.settings.register_profile(
+        "pytorch_ci",
+        settings(
+            derandomize=True,
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=50,
+            verbosity=hypothesis.Verbosity.normal))
+    # Default profile outside CI: 10 examples.
+    hypothesis.settings.register_profile(
+        "dev",
+        settings(
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=10,
+            verbosity=hypothesis.Verbosity.normal))
+    # Opt-in profile: 1000 examples with verbose output.
+    hypothesis.settings.register_profile(
+        "debug",
+        settings(
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=1000,
+            verbosity=hypothesis.Verbosity.verbose))
+
+    # Outside CI the profile comes from PYTORCH_HYPOTHESIS_PROFILE, falling back to "dev".
+    hypothesis.settings.load_profile(
+        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
+    )
+except ImportError:
+    # hypothesis is optional: without it only a warning is emitted.
+    warnings.warn('Fail to import hypothesis in common_utils, tests are not derandomized', ImportWarning, stacklevel=2)
+
+# Used in check_if_enable to see if a test method should be disabled by an issue,
+# sanitizes a test method name from appended suffixes by @dtypes parametrization.
+# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should
+# disabled ALL parametrized test_bitwise_ops tests, such test_bitwise_ops_cuda_int32
+def remove_device_and_dtype_suffixes(test_name: str) -> str:
+    # import statement is localized to avoid circular dependency issues with common_device_type.py
+    from torch.testing._internal.common_device_type import get_device_type_test_bases
+    device_suffixes = [x.device_type for x in get_device_type_test_bases()]
+    dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]
+
+    test_name_chunks = test_name.split("_")
+    if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
+        if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
+            return "_".join(test_name_chunks[0:-2])
+        return "_".join(test_name_chunks[0:-1])
+    return test_name
+
+
+def check_if_enable(test: unittest.TestCase):
+    classname = str(test.__class__).split("'")[1].split(".")[-1]
+    sanitized_testname = remove_device_and_dtype_suffixes(test._testMethodName)
+
+    def matches_test(target: str):
+        target_test_parts = target.split()
+        if len(target_test_parts) < 2:
+            # poorly formed target test name
+            return False
+        target_testname = target_test_parts[0]
+        target_classname = target_test_parts[1][1:-1].split(".")[-1]
+        # if test method name or its sanitized version exactly matches the disabled
+        # test method name AND allow non-parametrized suite names to disable
+        # parametrized ones (TestSuite disables TestSuiteCPU)
+        return classname.startswith(target_classname) and (target_testname in (test._testMethodName, sanitized_testname))
+
+    if any(matches_test(x) for x in slow_tests_dict):
+        getattr(test, test._testMethodName).__dict__['slow_test'] = True
+        if not TEST_WITH_SLOW:
+            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+
+    if not IS_SANDCASTLE:
+        should_skip = False
+        skip_msg = ""
+
+        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
+            if matches_test(disabled_test):
+                platform_to_conditional: dict = {
+                    "mac": IS_MACOS,
+                    "macos": IS_MACOS,
+                    "win": IS_WINDOWS,
+                    "windows": IS_WINDOWS,
+                    "linux": IS_LINUX,
+                    "rocm": TEST_WITH_ROCM,
+                    "xpu": TEST_XPU,
+                    "asan": TEST_WITH_ASAN,
+                    "dynamo": TEST_WITH_TORCHDYNAMO,
+                    "dynamo_wrapped": TEST_WITH_TORCHDYNAMO,
+                    "inductor": TEST_WITH_TORCHINDUCTOR,
+                    "slow": TEST_WITH_SLOW,
+                }
+
+                invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))
+                if len(invalid_platforms) > 0:
+                    invalid_plats_str = ", ".join(invalid_platforms)
+                    valid_plats = ", ".join(platform_to_conditional.keys())
+
+                    print(f"Test {disabled_test} is disabled for some unrecognized ",
+                          f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ",
+                          'assigned to this flaky test, changing "Platforms: ..." to a comma separated ',
+                          f"subset of the following (or leave it blank to match all platforms): {valid_plats}")
+
+                    # Sanitize the platforms list so that we continue to disable the test for any valid platforms given
+                    platforms = list(filter(lambda p: p in platform_to_conditional, platforms))
+
+                if platforms == [] or any(platform_to_conditional[platform] for platform in platforms):
+                    should_skip = True
+                    skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
+                        f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
+                        "If you're seeing this on your local machine and would like to enable this test, " \
+                        "please make sure CI is not set and you are not using the flag --import-disabled-tests."
+                    break
+
+        if should_skip and not RERUN_DISABLED_TESTS:
+            # Skip the disabled test when not running under --rerun-disabled-tests verification mode
+            raise unittest.SkipTest(skip_msg)
+
+        if not should_skip and RERUN_DISABLED_TESTS:
+            # Probably test has disable issue but not for this platform
+            skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \
+                " disabled tests are run"
+            raise unittest.SkipTest(skip_msg)
+
+    if TEST_SKIP_FAST:
+        if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
+            raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
+
+
+# `TestCase.assertEqual` is very permissive and coerces the inputs into a format that can be compared. This is very
+# convenient when writing tests, but not so much while reviewing them. By default, the comparison `Pair` framework of
+# `torch.testing._comparison.are_equal`, used for example by the public testing function
+# `torch.testing.assert_close`, is more strict. In order to use the same framework and thus reduce the divergence
+# between internal and external comparison logic as much as possible, we define some "relaxed" pairs here. They only
+# change the supported inputs, but the comparison logic is the same.
+# TODO: Revisit the relaxed pairs and check how much work it is to fix the tests that would fail without the relaxation.
+
+class RelaxedBooleanPair(BooleanPair):
+    """Pair for boolean-like inputs.
+
+    In contrast to the builtin :class:`BooleanPair`, this class also supports one input being a number or a single
+    element tensor-like.
+    """
+    _supported_number_types = NumberPair(0, 0)._supported_types
+
+    def _process_inputs(self, actual, expected, *, id):
+        # We require only one of the inputs of the inputs to be a boolean and the other can also be a boolean, a
+        # number, or a single element tensor or array, whereas in default BooleanPair both inputs have to be booleans.
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
+        other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
+        if not (
+            (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
+            or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
+        ):
+            self._inputs_not_supported()
+
+        return [self._to_bool(input, id=id) for input in (actual, expected)]
+
+    def _to_bool(self, bool_like, *, id):
+        if isinstance(bool_like, np.number):
+            return bool(bool_like.item())
+        elif type(bool_like) in self._supported_number_types:
+            return bool(bool_like)
+        elif isinstance(bool_like, (torch.Tensor, np.ndarray)):
+            numel = bool_like.numel() if isinstance(bool_like, torch.Tensor) else bool_like.size
+            if numel > 1:
+                self._fail(
+                    ValueError,
+                    f"Only single element tensor-likes can be compared against a boolean. "
+                    f"Got {numel} elements instead.",
+                    id=id
+                )
+
+            return bool(bool_like.item())
+        else:
+            return super()._to_bool(bool_like, id=id)
+
+
+class RelaxedNumberPair(NumberPair):
+    """Pair for number-like inputs.
+
+    In contrast to the builtin :class:`NumberPair`, this class also supports one input being a single element
+    tensor-like or a :class:`enum.Enum`. (D)Type checks are disabled, meaning comparing 1 to 1.0 succeeds even when
+    ``check_dtype=True`` is passed.
+
+    In addition, this class uses looser default tolerances for :class:`float` and :class:`complex` inputs. Also
+    supports overriding the absolute and relative tolerance through the ``@precisionOverride`` and
+    ``@toleranceOverride`` decorators.
+    """
+    _TYPE_TO_DTYPE = {
+        int: torch.int64,
+        float: torch.float32,
+        complex: torch.complex64,
+    }
+
+    def __init__(
+            self, actual, expected, *, rtol_override=0.0, atol_override=0.0, check_dtype=None, **other_parameters
+    ) -> None:
+        super().__init__(actual, expected, check_dtype=False, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _process_inputs(self, actual, expected, *, id):
+        # We require only one of the inputs of the inputs to be a number and the other can also be a number or a single
+        # element tensor or array, whereas in default NumberPair both inputs have to be numbers.
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
+        other_supported_types = (*self._supported_types, *tensor_or_array_types)
+        if not (
+                (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
+                or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
+        ):
+            self._inputs_not_supported()
+
+        return [self._to_number(input, id=id) for input in (actual, expected)]
+
+    def _to_number(self, number_like, *, id):
+        if isinstance(number_like, (torch.Tensor, np.ndarray)):
+            numel = number_like.numel() if isinstance(number_like, torch.Tensor) else number_like.size
+            if numel > 1:
+                self._fail(
+                    ValueError,
+                    f"Only single element tensor-likes can be compared against a number. "
+                    f"Got {numel} elements instead.",
+                    id=id
+                )
+            number = number_like.item()
+            if isinstance(number, bool):
+                number = int(number)
+
+            return number
+        elif isinstance(number_like, Enum):
+            return int(number_like)  # type: ignore[call-overload]
+        else:
+            number = super()._to_number(number_like, id=id)
+            if type(number) not in self._TYPE_TO_DTYPE:
+                self._inputs_not_supported()
+            return number
+
+
+class TensorOrArrayPair(TensorLikePair):
+    """Pair for tensor-like inputs.
+
+    On the one hand this class is stricter than the builtin :class:`TensorLikePair` since it only allows instances of
+    :class:`torch.Tensor` and :class:`numpy.ndarray` rather than allowing any tensor-like than can be converted into a
+    tensor. On the other hand this class is looser since it converts all inputs into tensors with no regard of their
+    relationship, e.g. comparing a :class:`torch.Tensor` to :class:`numpy.ndarray` is fine.
+
+    In addition, this class supports overriding the absolute and relative tolerance through the ``@precisionOverride``
+    and ``@toleranceOverride`` decorators.
+    """
+    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
+        super().__init__(actual, expected, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
+        self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray))
+
+        actual, expected = (self._to_tensor(input) for input in (actual, expected))
+        for tensor in (actual, expected):
+            self._check_supported(tensor, id=id)
+        return actual, expected
+
+
+class TypedStoragePair(TensorLikePair):
+    """Pair for :class:`torch.storage.TypedStorage` inputs."""
+    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
+        self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage)
+        super().__init__(actual, expected, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _to_tensor(self, typed_storage):
+        return torch.tensor(
+            typed_storage._untyped_storage,
+            dtype={
+                torch.quint8: torch.uint8,
+                torch.quint4x2: torch.uint8,
+                torch.quint2x4: torch.uint8,
+                torch.qint32: torch.int32,
+                torch.qint8: torch.int8
+            }.get(typed_storage.dtype, typed_storage.dtype),
+            device=typed_storage.device,
+        )
+
+
+class UnittestPair(Pair):
+    """Fallback ABC pair that handles non-numeric inputs.
+
+    To avoid recreating the mismatch messages of :meth:`unittest.TestCase.assertEqual`, this pair simply wraps it in
+    order to use it with the :class:`Pair` "framework" from :func:`are_equal`.
+
+    Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support.
+    """
+    CLS: Union[type, tuple[type, ...]]
+    TYPE_NAME: Optional[str] = None
+
+    def __init__(self, actual, expected, **other_parameters):
+        self._check_inputs_isinstance(actual, expected, cls=self.CLS)
+        super().__init__(actual, expected, **other_parameters)
+
+    def compare(self):
+        test_case = unittest.TestCase()
+
+        try:
+            return test_case.assertEqual(self.actual, self.expected)
+        except test_case.failureException as error:
+            msg = str(error)
+
+        type_name = self.TYPE_NAME or (self.CLS if isinstance(self.CLS, type) else self.CLS[0]).__name__
+        self._fail(AssertionError, f"{type_name.title()} comparison failed: {msg}")
+
+
+class StringPair(UnittestPair):
+    """Pair for :class:`str` and :class:`bytes` inputs; failures are reported under the name "string"."""
+    CLS = (str, bytes)
+    TYPE_NAME = "string"
+
+
+class SetPair(UnittestPair):
+    """Pair for :class:`set` inputs."""
+    CLS = set
+
+
+class TypePair(UnittestPair):
+    """Pair for inputs that are themselves classes (instances of :class:`type`)."""
+    CLS = type
+
+
+class ObjectPair(UnittestPair):
+    """Pair accepting any :class:`object` inputs."""
+    CLS = object
+
+
+# This implements a variant of assertRaises/assertRaisesRegex where we first test
+# if the exception is NotImplementedError, and if so just skip the test instead
+# of failing it.
+#
+# This is implemented by inheriting from the (private) implementation of
+# assertRaises from unittest.case, and slightly tweaking it for this new
+# behavior.  The year is 2021: this private class hierarchy hasn't changed since
+# 2010, seems low risk to inherit from.
+class AssertRaisesContextIgnoreNotImplementedError(unittest.case._AssertRaisesContext):
+    def __exit__(self, exc_type, exc_value, tb):
+        if exc_type is not None and issubclass(exc_type, NotImplementedError):
+            self.test_case.skipTest(f"not_implemented: {exc_value}")  # type: ignore[attr-defined]
+        return super().__exit__(exc_type, exc_value, tb)
+
+
+@contextmanager
+def set_warn_always_context(new_val: bool):
+    old_val = torch.is_warn_always_enabled()
+    torch.set_warn_always(new_val)
+    try:
+        yield
+    finally:
+        torch.set_warn_always(old_val)
+
+
+class NoTest:
+    """Marker class whose ``__test__ = False`` attribute keeps pytest from collecting it."""
+    # causes pytest to not recognize this class as a test
+    __test__ = False
+
+
+class TestCase(expecttest.TestCase):
+    # NOTE: "precision" lets classes and generated tests set minimum
+    # atol values when comparing tensors. Used by @precisionOverride and @toleranceOverride, for
+    # example.
+    # NOTE: "rel_tol" lets classes and generated tests set minimum
+    # rtol values when comparing tensors. Used by @toleranceOverride, for example.
+    _precision: float = 0
+    _rel_tol: float = 0
+
+    # Toggles whether to assert that `torch.get_default_dtype()` returns
+    # `torch.float` when `setUp` and `tearDown` are called.
+    _default_dtype_check_enabled: bool = False
+
+    # Always use difflib to print diffs on multi line equality.
+    # Undocumented feature in unittest
+    _diffThreshold = sys.maxsize
+    maxDiff = None
+
+    # checker to early terminate test suite if unrecoverable failure occurs.
+    def _should_stop_test_suite(self):
+        if torch.cuda.is_initialized():
+            # CUDA device side error will cause subsequence test cases to fail.
+            # stop entire test suite if catches RuntimeError during torch.cuda.synchronize().
+            try:
+                torch.cuda.synchronize()
+            except RuntimeError as rte:
+                print("TEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure", file=sys.stderr)
+                print(str(rte), file=sys.stderr)
+                return True
+            return False
+        else:
+            return False
+
+    @property
+    def precision(self) -> float:
+        return self._precision
+
+    @precision.setter
+    def precision(self, prec: float) -> None:
+        self._precision = prec
+
+    @property
+    def rel_tol(self) -> float:
+        return self._rel_tol
+
+    @rel_tol.setter
+    def rel_tol(self, prec: float) -> None:
+        self._rel_tol = prec
+
+    _do_cuda_memory_leak_check = False
+    _do_cuda_non_default_stream = False
+
+    # When True, if a test case raises a NotImplementedError, instead of failing
+    # the test, skip it instead.
+    _ignore_not_implemented_error = False
+
+    def __init__(self, method_name='runTest', methodName='runTest'):
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+
+        test_method = getattr(self, method_name, None)
+        if test_method is not None:
+            # Wraps the tested method if we should do CUDA memory check.
+            if TEST_CUDA_MEM_LEAK_CHECK:
+                self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
+                # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
+                if self._do_cuda_memory_leak_check and not IS_WINDOWS:
+                    self.wrap_with_cuda_policy(method_name, self.assertLeaksNoCudaTensors)
+
+            # Wraps the tested method if we should enforce non default CUDA stream.
+            self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True)
+            if self._do_cuda_non_default_stream and not IS_WINDOWS:
+                self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream)
+
+            if self._ignore_not_implemented_error:
+                self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))
+
+            if PRINT_REPRO_ON_FAILURE:
+                try:
+                    def _get_rel_test_path(abs_test_path):
+                        # Attempt to get relative path based on the "test" dir.
+                        # In CI, the working dir is not guaranteed to be the base repo dir so
+                        # we can't just compute relative path from that.
+                        parts = Path(abs_test_path).parts
+                        for i, part in enumerate(parts):
+                            if part == "test":
+                                base_dir = os.path.join(*parts[:i]) if i > 0 else ''
+                                return os.path.relpath(abs_test_path, start=base_dir)
+
+                        # Can't determine containing dir; just return the test filename.
+                        # The path isn't strictly correct but it's arguably better than nothing.
+                        return os.path.split(abs_test_path)[1]
+
+                    abs_test_path = inspect.getfile(type(self))
+                    test_filename = _get_rel_test_path(abs_test_path)
+                    class_name = type(self).__name__
+                    test_run_cmd = f"python {test_filename} {class_name}.{method_name}"
+                    env_var_prefix = TestEnvironment.repro_env_var_prefix()
+                    repro_parts = [env_var_prefix, test_run_cmd]
+                    self.wrap_with_policy(
+                        method_name,
+                        lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
+                except Exception as e:
+                    # Don't fail entirely if we can't get the test filename
+                    log.info("could not print repro string", extra=str(e))  # type: ignore[arg-type]
+
+    def assertLeaksNoCudaTensors(self, name=None):
+        name = self.id() if name is None else name
+        return CudaMemoryLeakCheck(self, name)
+
+    def enforceNonDefaultStream(self):
+        return CudaNonDefaultStream()
+
+    def _remove_ansi_escape(self, input):
+        # 7-bit C1 ANSI sequences
+        ansi_escape = re.compile(r'''
+            \x1B  # ESC
+            (?:   # 7-bit C1 Fe (except CSI)
+                [@-Z\\-_]
+            |     # or [ for CSI, followed by a control sequence
+                \[
+                [0-?]*  # Parameter bytes
+                [ -/]*  # Intermediate bytes
+                [@-~]   # Final byte
+            )
+        ''', re.VERBOSE)
+        return ansi_escape.sub('', input)
+
+    def remove_comment_lines(self, input_string):
+        lines = input_string.split('\n')
+        filtered_lines = [line for line in lines if not line.strip().startswith('#')]
+        return '\n'.join(filtered_lines)
+
+    def remove_empty_lines(self, input_string):
+        lines = input_string.split('\n')
+        filtered_lines = [line for line in lines if line.strip() != '']
+        return '\n'.join(filtered_lines)
+
+    # ignore comments will ignore lines that starts with # after being stripped
+    def assertExpectedInline(self, actual, expect, skip=0, ignore_comments=False, ignore_empty_lines=False):
+        actual = actual if isinstance(actual, str) else str(actual)
+        actual = self._remove_ansi_escape(actual)
+        expect = self._remove_ansi_escape(expect)
+        if ignore_comments:
+            actual = self.remove_comment_lines(actual)
+            expect = self.remove_comment_lines(expect)
+
+        if ignore_empty_lines:
+            actual = self.remove_empty_lines(actual)
+            expect = self.remove_empty_lines(expect)
+
+        return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1)
+
+    # Munges exceptions that internally contain stack traces, using munge_exc
+    def assertExpectedInlineMunged(
+        self, exc_type, callable, expect, *, skip=0, suppress_suffix=True, post_munge=None,
+    ):
+        try:
+            callable()
+        except exc_type as e:
+            munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=skip + 1)
+            if post_munge:
+                munged = post_munge(munged)
+            self.assertExpectedInline(
+                munged, expect, skip=skip + 1
+            )
+            return
+        self.fail(msg="Did not raise when expected to")
+
+    def assertLogs(self, logger=None, level=None):
+        if logger is None:
+            logger = logging.getLogger("torch")
+        return super().assertLogs(logger, level)
+
+    def assertNoLogs(self, logger=None, level=None):
+        if logger is None:
+            logger = logging.getLogger("torch")
+        return super().assertNoLogs(logger, level)
+
+    def wrap_with_cuda_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        # the import below may initialize CUDA context, so we do it only if
+        # self._do_cuda_memory_leak_check or self._do_cuda_non_default_stream
+        # is True.
+        # TODO: sure looks like we unconditionally initialize the context here
+        # -- ezyang
+        from torch.testing._internal.common_cuda import TEST_CUDA
+        fullname = self.id().lower()  # class_name.method_name
+        if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
+            setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
+
+    def wrap_with_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
+
+    # A policy is a zero-argument function that returns a context manager.
+    # We don't take the context manager directly as it may be necessary to
+    # construct it once per test method
+    def wrap_method_with_policy(self, method, policy):
+        # Assumes that `method` is the tested function in `self`.
+        # NOTE: Python Exceptions (e.g., unittest.Skip) keeps objects in scope
+        #       alive, so this cannot be done in setUp and tearDown because
+        #       tearDown is run unconditionally no matter whether the test
+        #       passes or not. For the same reason, we can't wrap the `method`
+        #       call in try-finally and always do the check.
+        @wraps(method)
+        def wrapper(self, *args, **kwargs):
+            with policy():
+                method(*args, **kwargs)
+        return types.MethodType(wrapper, self)
+
+    def wrap_with_cuda_memory_check(self, method):
+        return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
+
+    def _dynamo_test_key(self):
+        return f"{self.__class__.__name__}.{self._testMethodName}"
+
+    def compile_fn(self, fn, backend, nopython):
+        # Allows subclasses to control compilation
+        return torch._dynamo.optimize(backend, nopython=nopython)(fn)
+
+    def _run_custom(self, result=None):
+        """Run this test through ``super().run``, optionally under a compile mode.
+
+        Steps, in order:
+
+        1. When one of ``TEST_WITH_TORCHDYNAMO`` / ``TEST_WITH_AOT_EAGER`` /
+           ``TEST_WITH_TORCHINDUCTOR`` is set, derive a default strictness
+           from the test file name (overridable with the ``STRICT_DEFAULT``
+           environment variable set to ``"1"``), then let a ``dynamo_strict``
+           attribute on the test method or on the class take precedence.
+        2. Reset dynamo when running strictly (or when flagged for inductor)
+           and set the compiler stance to ``"default"``.
+        3. Patch ``torch._dynamo.config.suppress_errors`` (and, for
+           ``test_ops`` under inductor, disable inductor size asserts) for the
+           duration of the run.
+        4. Wrap ``super().run`` with the selected backend and, for the
+           dynamo/inductor modes, wrap the test method according to the
+           expected-failure and skip lists from ``dynamo_test_failures``.
+        5. After the run, reset dynamo (or compiled autograd) and, when
+           running under ``unittest`` and ``self._should_stop_test_suite()``
+           is true, stop the suite, first recording a fake failure if the
+           result is still successful.
+
+        Args:
+            result: the result object forwarded to ``super().run``.
+        """
+        using_unittest = isinstance(result, unittest.TestResult)
+
+        super_run = super().run
+        test_cls = super_run.__self__  # type: ignore[attr-defined]
+
+        # Are we compiling?
+        compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
+        # Is the class strict and compiling?
+        strict_default = False
+        should_reset_dynamo = False
+
+        # We disable size_asserts for test_ops since some tests fail
+        # due to mismatch of strides returned from eager v.s. meta kernels
+        # Only some of the ops has this problem, but since tests in
+        # test_op.py are parametrized, it's hard to do this specifically
+        # for the affected ops.
+        # It's not a big deal since these problems are captured by
+        # test_torchinductor_opinfo.py as well.
+        should_disable_size_asserts = False
+        if compiled:
+            try:
+                path = inspect.getfile(type(test_cls))
+                full_path = os.path.abspath(path)
+                # Capture the path of the test file relative to a "test/" directory,
+                # without the trailing extension.
+                match = re.match(r".*/test/(.*).py", full_path)
+                if match is not None:
+                    filename = match.group(1)
+                    if TEST_WITH_TORCHINDUCTOR:
+                        from .dynamo_test_failures import FIXME_inductor_non_strict
+                        strict_default = filename not in FIXME_inductor_non_strict
+                        should_reset_dynamo = True
+
+                        if filename == "test_ops":
+                            should_disable_size_asserts = True
+                    else:
+                        strict_default = True
+            # inspect.getfile can fail with these
+            except (OSError, TypeError):
+                pass
+            # The environment variable can only turn strictness on, never off.
+            if "STRICT_DEFAULT" in os.environ:
+                if os.environ["STRICT_DEFAULT"] == "1":
+                    strict_default = True
+
+        # Precedence for strictness: test method attribute, then class
+        # attribute, then the default computed above.
+        strict_mode = False
+        if compiled:
+            test_method = getattr(self, self._testMethodName)
+            if hasattr(test_method, "dynamo_strict"):
+                strict_mode = test_method.dynamo_strict
+            elif hasattr(test_cls, "dynamo_strict"):
+                strict_mode = test_cls.dynamo_strict
+            else:
+                strict_mode = strict_default
+        nopython = getattr(test_cls, "dynamo_strict_nopython", False) and compiled
+
+        if strict_mode or should_reset_dynamo:
+            torch._dynamo.reset()
+
+        torch.compiler.set_stance("default")
+
+        # TODO: Remove this; this is grandfathered in because we suppressed errors
+        # on test suite previously
+        # When strict mode is False, suppress_errors is True
+        if compiled:
+            suppress_errors = not strict_mode
+        else:
+            suppress_errors = torch._dynamo.config.suppress_errors
+
+        maybe_disable_size_asserts = (
+            torch._inductor.config.patch(size_asserts=False)
+            if should_disable_size_asserts
+            else contextlib.nullcontext()
+        )
+
+        with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
+            if TEST_WITH_AOT_EAGER:
+                super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
+            elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
+                if TEST_WITH_TORCHINDUCTOR:
+                    super_run = self.compile_fn(super_run, "inductor", nopython)
+                else:
+                    # Assume eager-generated GraphModules will not error out.
+                    # If we do, this is probably a Dynamo bug!
+                    super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
+
+                key = self._dynamo_test_key()
+
+                # Wrapper for tests listed as expected failures: any exception
+                # becomes a skip, while a clean pass raises so the stale entry
+                # gets noticed.
+                def expect_failure(f, file_name):
+                    @wraps(f)
+                    def wrapper(*args, **kwargs):
+                        try:
+                            f(*args, **kwargs)
+                        except BaseException as e:  # noqa: B036
+                            self.skipTest(e)
+                        raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
+                    return wrapper
+
+                if TEST_WITH_TORCHINDUCTOR:
+                    subdir = "test/inductor_expected_failures"
+                    from .dynamo_test_failures import inductor_expected_failures as expected_failures
+                else:
+                    subdir = "test/dynamo_expected_failures"
+                    from .dynamo_test_failures import dynamo_expected_failures as expected_failures
+
+                if key in expected_failures:
+                    method = getattr(self, self._testMethodName)
+                    file_name = os.path.join(subdir, key)
+                    setattr(self, self._testMethodName, expect_failure(method, file_name))
+
+                # Wrapper for tests listed as skips: the test still runs, but
+                # every outcome (exception or pass) is reported as a skip.
+                def ignore_failure(f, file_name):
+                    @wraps(f)
+                    def wrapper(*args, **kwargs):
+                        try:
+                            f(*args, **kwargs)
+                        except BaseException as e:  # noqa: B036
+                            self.skipTest(e)
+                        method = getattr(self, self._testMethodName)
+                        if getattr(method, "__unittest_expecting_failure__", False):
+                            self.skipTest("unexpected success")
+                        else:
+                            self.skipTest(f"This test passed, maybe we can remove `{file_name}`")
+                    return wrapper
+
+                if TEST_WITH_TORCHINDUCTOR:
+                    subdir = "test/inductor_skips"
+                    from .dynamo_test_failures import inductor_skips as skips
+                else:
+                    subdir = "test/dynamo_skips"
+                    from .dynamo_test_failures import dynamo_skips as skips
+
+                if key in skips:
+                    method = getattr(self, self._testMethodName)
+                    file_name = os.path.join(subdir, key)
+                    setattr(self, self._testMethodName, ignore_failure(method, file_name))
+
+                from .dynamo_test_failures import compiled_autograd_skips
+                if torch._dynamo.config.compiled_autograd and key in compiled_autograd_skips:
+                    # Still run the test, but with compiled autograd disabled
+                    super_run = runWithoutCompiledAutograd()(super_run)
+
+            super_run(result=result)
+
+        if strict_mode or should_reset_dynamo:
+            torch._dynamo.reset()
+        elif torch._dynamo.config.compiled_autograd:
+            torch._dynamo.compiled_autograd.reset()
+
+        # Early terminate test if necessary.  If using pytest, use the -x flag instead
+        if using_unittest and self._should_stop_test_suite():
+            if result.wasSuccessful():
+                case = TestCase()
+                if TEST_SAVE_XML is not None:
+                    # This is a bit hacky, XMLRunner modifies expected type from TestCase to TestInfo
+                    # Create dummy TestInfo to record results correctly
+                    from xmlrunner.result import _TestInfo  # type: ignore[import]
+                    case = _TestInfo(result, case)
+                    case.output = _TestInfo.ERROR  # type: ignore[attr-defined]
+                    case.elapsed_time = 0.0  # type: ignore[attr-defined]
+                    case.test_description = "TestSuiteEarlyFailure"  # type: ignore[attr-defined]
+                # This shouldn't really happen, but if it does, add a fake failure
+                # For more details see https://github.com/pytorch/pytorch/issues/71973
+                result.failures.append((case, "TestSuite execution was aborted early"))
+                assert result.wasSuccessful() is False
+            result.stop()
+
+
+    def run(self, result=None):
+        with contextlib.ExitStack() as stack:
+            if TEST_WITH_CROSSREF:
+                stack.enter_context(CrossRefMode())
+            self._run_custom(
+                result=result,
+            )
+
+    def setUp(self):
+        check_if_enable(self)
+        set_rng_seed()
+
+        # Save global check sparse tensor invariants state that can be
+        # restored from tearDown:
+        self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled()
+
+        # Enable invariant checks for all sparse tensors constructions
+        # including the unsafe ones. If this is not desired for some
+        # test case, use check_invariants=False optional argument to
+        # sparse tensor constructors or
+        # @torch.sparse.check_sparse_tensor_invariants(False)
+        # decorator to disable the invariant checks.
+        torch.sparse.check_sparse_tensor_invariants.enable()
+
+        if self._default_dtype_check_enabled:
+            assert torch.get_default_dtype() == torch.float
+
+        # attempt to reset some global state at the end of the test
+        self._prev_grad_state = torch.is_grad_enabled()
+
+    def tearDown(self):
+        # There exists test cases that override TestCase.setUp
+        # definition, so we cannot assume that _check_invariants
+        # attribute is defined in general.
+        if hasattr(self, '_check_invariants'):
+            # Restore the global check sparse tensor invariants state
+            if self._check_invariants:
+                torch.sparse.check_sparse_tensor_invariants.enable()
+            else:
+                torch.sparse.check_sparse_tensor_invariants.disable()
+
+        if self._default_dtype_check_enabled:
+            assert torch.get_default_dtype() == torch.float
+
+        # attribute may not be defined, per above
+        if hasattr(self, '_prev_grad_state'):
+            torch.set_grad_enabled(self._prev_grad_state)
+
+    @staticmethod
+    def _make_crow_indices(n_rows, n_cols, nnz,
+                           *, device, dtype, random=True):
+        """Return crow_indices of a CSR tensor with size (n_rows, n_cols) and
+        the number of specified elements nnz.
+
+        If random is True, the column counts of rows are in random
+        order. Otherwise, the column counts of rows are defined by the
+        used sampling method.
+
+        Sampling method
+        ---------------
+
+        The used sampling method was introduced in
+        https://pearu.github.io/csr_sampling.html, and here we give
+        only an overall description of the method.
+
+        Notice that crow_indices can be defined as cumsum(counts)
+        where counts is a sequence of non-negative integers satisfying
+        the following conditions:
+
+          len(counts) == n_rows + 1
+          counts.max() <= n_cols
+
+        while counts[i + 1] is interpreted as the number of specified
+        elements in the i-th row.
+
+        The used sampling method aims at increasing the diversity of
+        CSR samples, that is, a CSR sample should contain (i) rows
+        that are all filled, (ii) rows with no elements at all, and
+        (iii) rows that are partially filled. At the same time and for
+        the given total number of specified elements (nnz), there
+        should be minimal preference to rows with a given number of
+        elements.  To achieve this, the sampling method is built-up on
+        using a sawteeth model for counts. In the simplest case, we
+        would have
+
+          counts = arange(n_rows + 1) % (n_cols + 1)
+
+        that has equal number of all possible column counts per row.
+        This formula can be used only for specific input values of
+        n_rows, n_cols, and nnz. To generalize this model to any
+        combinations of inputs, the counts model above is extended
+        with an incomplete sawtooth, and the right and lower
+        rectangular parts that will guarantee that
+
+          counts.sum() == nnz
+
+        for any combination of n_rows, n_cols, and nnz. Basically,
+        we'll find a maximal window in (n_rows + 1, n_cols + 1)-grid
+        that is able to hold a sequence of sawteeth and so-called
+        final correction, while the external part of the window is
+        filled with counts to meet the nnz constraint exactly.
+        """
+        assert 0 <= nnz <= n_rows * n_cols, (nnz, n_rows, n_cols)
+
+        def sawteeth(n, m):
+            # return the total number of counts in the sequence of
+            # sawteeth where n and m define a window in (n_rows+1,
+            # n_cols+1) rectangle where the sequence of sawteeth
+            # perfectly fit.
+            M = (n_cols - m) * (n_cols - m + 1) // 2
+            K = (n_rows - n) % (n_cols - m + 1)
+            return M * ((n_rows - n) // (n_cols - m + 1)) + K * (K - 1) // 2
+
+        # Different from the original method description, here counts
+        # has leading 0 required by crow_indices:
+        counts = torch.zeros(n_rows + 1, dtype=dtype, device=torch.device('cpu'))
+
+        n = m = 0
+        N = sawteeth(n, m)
+        if N and nnz >= max(N, n_cols):
+            # determine the width of the sawteeth window. We use bisection to solve
+            #   N(n, 0) == 0 or nnz - n * n_cols < max(N(n, 0), n_cols)
+            # for n
+            n_left = n
+            n_right = n_rows - 1
+            N_right = sawteeth(n_right, m)
+            while n_right - n_left > 1:
+                n_middle = (n_left + n_right) // 2
+                N_middle = sawteeth(n_middle, m)
+                if N_middle == 0 or nnz - n_middle * n_cols < max(N_middle, n_cols):
+                    n_right, N_right = n_middle, N_middle
+                else:
+                    n_left = n_middle
+            n, N = n_right, N_right
+            # fill the right rectangle with counts:
+            assert n
+            counts[-n:].fill_(n_cols)
+
+        if N and nnz - n * n_cols >= max(N, n_rows - n):
+            # determine the height of the sawteeth window. We use bisection to solve
+            #   N(n, m) == 0 or nnz - n * n_cols - m * (n_rows - n) < max(N(n, m), n_rows - n)
+            # for m.
+            m_left = m
+            m_right = n_cols - 1
+            N_right = sawteeth(n, m_right)
+            while m_right - m_left > 1:
+                m_middle = (m_left + m_right) // 2
+                N_middle = sawteeth(n, m_middle)
+                if N_middle == 0 or nnz - n * n_cols - m_middle * (n_rows - n) < max(N_middle, n_rows - n):
+                    m_right, N_right = m_middle, N_middle
+                else:
+                    m_left = m_middle
+            m, N = m_right, N_right
+            # fill the bottom rectangle with counts:
+            assert m
+            counts[1:n_rows - n + 1].fill_(m)
+
+        if N:
+            # fill the sawteeth window with counts
+            q, r = divmod(nnz - n * n_cols - m * (n_rows - n),
+                          (n_cols - m) * (n_cols - m + 1) // 2)
+            p = 1 + q * (n_cols - m + 1)
+            k = math.isqrt(2 * r)
+            if k * (k + 1) > 2 * r:
+                k -= 1
+            corr = r - k * (k + 1) // 2
+            assert not ((p > 1) and (m > 0))  # full sawteeth are never on top of a bottom rectangle
+            # sequence of full sawteeth:
+            counts[1:p] = torch.arange(p - 1, dtype=dtype, device=counts.device) % (n_cols - m + 1)
+            # incomplete sawtooth:
+            counts[p:p + k + 1] += torch.arange(k + 1, dtype=dtype, device=counts.device)
+        else:
+            # given input does not support sawteeth
+            p = 1
+            corr = nnz - n * n_cols - m * (n_rows - n)
+
+        # correction that will guarantee counts.sum() == nnz:
+        counts[p] += corr
+
+        if random:
+            # randomize crow_indices by shuffling the sawteeth
+            # sequence:
+            perm = torch.randperm(n_rows, device=counts.device)
+            counts[1:] = counts[1:][perm]
+
+        # compute crow_indices:
+        crow_indices = counts
+        crow_indices.cumsum_(dim=0)
+        return crow_indices.to(device=device)
+
+    def genSparseCompressedTensor(self, size, nnz, *, layout, device, dtype, index_dtype, blocksize=(), dense_dims=0):
+        from operator import mul
+        from functools import reduce
+        sparse_dim = 2
+        assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
+        assert len(size) >= sparse_dim
+        if blocksize:
+            assert len(blocksize) == 2, (size, blocksize)
+            assert size[-2 - dense_dims] % blocksize[0] == 0, (size, blocksize)
+            assert size[-1 - dense_dims] % blocksize[1] == 0, (size, blocksize)
+            blocksize0, blocksize1 = blocksize
+        else:
+            blocksize0 = blocksize1 = 1
+
+        size = tuple(size)
+        dense_size = size[(len(size) - dense_dims):]
+
+        def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
+            compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype)
+            plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device)
+            for i in range(n_compressed_dims):
+                count = compressed_indices[i + 1] - compressed_indices[i]
+                plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort(
+                    torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count])
+            low = -1 if dtype != torch.uint8 else 0
+            high = 1 if dtype != torch.uint8 else 2
+            values = make_tensor((nnz,) + blocksize + dense_size, device=device, dtype=dtype, low=low, high=high)
+            return values, compressed_indices, plain_indices
+
+        batch_shape = size[:-2 - dense_dims]
+        n_batch = reduce(mul, batch_shape, 1)
+
+        if layout in {torch.sparse_csr, torch.sparse_bsr}:
+            n_compressed_dims, n_plain_dims = size[-2 - dense_dims] // blocksize0, size[-1 - dense_dims] // blocksize1
+        else:
+            n_compressed_dims, n_plain_dims = size[-1 - dense_dims] // blocksize1, size[-2 - dense_dims] // blocksize0
+        blocknnz = nnz // (blocksize0 * blocksize1)
+        sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)]
+        sparse_tensors_it = map(list, zip(*sparse_tensors, strict=True))
+
+        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize, *dense_size)
+        compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+        plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+        return torch.sparse_compressed_tensor(compressed_indices, plain_indices,
+                                              values, size=size, dtype=dtype, layout=layout, device=device)
+
+    def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims)
+
+    def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=0)
+
+    def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        assert len(blocksize) == 2
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
+
+    def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        assert len(blocksize) == 2
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
+
+    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype):
+        # Assert not given impossible combination, where the sparse dims have
+        # empty numel, but nnz > 0 makes the indices containing values.
+        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
+
+        v_size = [nnz] + list(size[sparse_dim:])
+        v = make_tensor(v_size, device=device, dtype=dtype, low=-1, high=1)
+        i = torch.rand(sparse_dim, nnz, device=device)
+        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
+        i = i.to(torch.long)
+        if is_uncoalesced:
+            i1 = i[:, :(nnz // 2), ...]
+            i2 = i[:, :((nnz + 1) // 2), ...]
+            i = torch.cat([i1, i2], 1)
+        x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)
+
+        if not is_uncoalesced:
+            x = x.coalesce()
+        else:
+            # FIXME: `x` is a sparse view of `v`. Currently rebase_history for
+            #        sparse views is not implemented, so this workaround is
+            #        needed for inplace operations done on `x`, e.g., copy_().
+            #        Remove after implementing something equivalent to CopySlice
+            #        for sparse views.
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
+            x = x.detach().clone()._coalesced_(False)
+        return x, x._indices().clone(), x._values().clone()
+
+    def generate_simple_inputs(self, layout,
+                               device=None,
+                               dtype=None,
+                               index_dtype=None,
+                               pin_memory=None,
+                               members_pin_memory=None,
+                               enable_batch=True,
+                               enable_hybrid=True,
+                               enable_zero_sized=True,
+                               enable_non_contiguous_indices=True,
+                               enable_non_contiguous_values=True,
+                               enable_batch_variable_nse=False,
+                               output_tensor=True,
+                               patterns=None):
+        """Generator of simple inputs for tensor constructors of the given layout.
+
+        The generated tensor inputs have the following properties:
+
+        - tensor shapes are minimal but not trivial
+        - tensor values are sorted sequences for COO and CSR formats, e.g. [1, 2, 3, 4]
+        - the generated tensors represent the same mathematical tensor for all layouts
+        - the generated tensors include regular, zero-sized, and optionally, batched or/and hybrid tensors.
+        - the generated tensors include contiguous or non-contiguous tensors both in indices and values
+
+        If output_tensor is True, yield tensors with the given
+        layout. Otherwise, yield inputs to the corresponding tensor
+        constructors:
+
+          - sparse compressed input is defined as
+            (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype,
+                                                              pin_memory=pin_memory)
+
+          - sparse COO input is defined as
+            (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory)
+
+          - strided input is defined as
+            (values,), dict(device=device, dtype=dtype)
+        """
+        if index_dtype is None:
+            index_dtype = torch.int64
+
+        is_compressed_sparse_layout = layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+
+        if output_tensor:
+            for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype,
+                                                            pin_memory=pin_memory,
+                                                            enable_batch=enable_batch, enable_hybrid=enable_hybrid,
+                                                            enable_zero_sized=enable_zero_sized,
+                                                            enable_non_contiguous_indices=enable_non_contiguous_indices,
+                                                            enable_non_contiguous_values=enable_non_contiguous_values,
+                                                            enable_batch_variable_nse=enable_batch_variable_nse,
+                                                            output_tensor=False):
+                if members_pin_memory:
+                    args = tuple(a.pin_memory() for a in args)
+                if layout is torch.strided:
+                    assert len(args) == 1
+                    size = kwargs.pop('size', None)  # to ensure that a zero-sized tensor has the desired shape
+                    assert size is not None
+                    if pin_memory:
+                        yield args[0].reshape(size).pin_memory()
+                    else:
+                        yield args[0].reshape(size)
+                elif layout is torch.sparse_coo:
+                    yield torch.sparse_coo_tensor(*args, **kwargs)
+                elif is_compressed_sparse_layout:
+                    kwargs.update(layout=layout)
+                    yield torch.sparse_compressed_tensor(*args, **kwargs)
+                else:
+                    assert 0  # unreachable
+            return
+
+        def get_blockpattern(pattern, blocksize):
+            basesize = pattern.shape
+            assert basesize[0] % blocksize[0] == 0, (basesize, blocksize)
+            assert basesize[1] % blocksize[1] == 0, (basesize, blocksize)
+            blockpattern = pattern.reshape(-1,
+                                           blocksize[0],
+                                           basesize[1] // blocksize[1],
+                                           blocksize[1]).transpose(-3, -2).any(-1).any(-1)
+            block_ids = torch.arange(1, blockpattern.numel() + 1).reshape(blockpattern.shape)
+            return (blockpattern != 0) * block_ids
+
+        def get_sparse_data(pattern):
+            # Build, for a 2-D pattern, the index and values tensors of the
+            # non-block layouts.  Returns a dict keyed by layout:
+            #   torch.sparse_coo -> (coo_indices, values)
+            #   torch.sparse_csr -> (crow_indices, col_indices, values)
+            #   torch.sparse_csc -> (ccol_indices, row_indices, csc_values)
+            #   torch.strided    -> (strided_values,)
+            # All tensors are int64.  The values are 1..nnz (not the numbers
+            # stored in `pattern`), assigned in the order in which
+            # torch.where reports the non-zero positions of `pattern`.
+            basesize = pattern.shape
+            assert len(basesize) == 2, basesize  # pattern is expected to be a matrix
+
+            # We cannot use `torch.sparse_xyz_tensor(pattern)` to
+            # compute the sparse layout indices and values because
+            # generate_simple_inputs is used to generate the inputs to
+            # test `torch.sparse_xyz_tensor` factory functions, so
+            # we'll compute the indices and values independently of
+            # the factory functions.
+
+            indices = torch.where(pattern != 0)
+            coo_indices = torch.stack(indices)
+            # compressed row pointers: running total of per-row counts,
+            # with a leading 0
+            crow_indices = torch.zeros(basesize[0] + 1, dtype=torch.int64)
+            crow_indices[1:] = torch.cumsum(coo_indices[0].bincount(minlength=basesize[0]), 0)
+            col_indices = coo_indices[1]
+            strided_values = torch.zeros(basesize, dtype=torch.int64)
+
+            # the property of `values == range(1, 1+nnz)` is used in
+            # get_sparse_data_with_block to relate BSR and BSC values,
+            # so, don't change the following line:
+            values = torch.arange(1, 1 + len(indices[0]), dtype=torch.int64)
+            strided_values[indices] = values
+
+            # Repeat the construction on the transposed pattern to obtain
+            # the column-compressed indices; csc_values are the same 1..nnz
+            # values read back from strided_values in the transposed order.
+            indices_T = torch.where(pattern.transpose(0, 1) != 0)
+            coo_indices_T = torch.stack(indices_T)
+            ccol_indices = torch.zeros(basesize[1] + 1, dtype=torch.int64)
+            ccol_indices[1:] = torch.cumsum(coo_indices_T[0].bincount(minlength=basesize[1]), 0)
+            row_indices = coo_indices_T[1]
+            csc_values = strided_values.transpose(0, 1)[indices_T]
+
+            return {torch.sparse_coo: (coo_indices, values),
+                    torch.sparse_csr: (crow_indices, col_indices, values),
+                    torch.sparse_csc: (ccol_indices, row_indices, csc_values),
+                    torch.strided: (strided_values,)}
+
+        def get_sparse_data_with_block(pattern, blocksize):
+            # Extend get_sparse_data(pattern) with the block layouts.  The
+            # returned dict holds every non-block entry plus
+            #   torch.sparse_bsr -> (crow_indices, col_indices, bsr_values)
+            #   torch.sparse_bsc -> (ccol_indices, row_indices, bsc_values)
+            # where the indices are those of the block-level pattern and
+            # each value is a blocksize[0] x blocksize[1] slice of the
+            # strided values.
+            nonblock_data = get_sparse_data(pattern)
+            blockpattern = get_blockpattern(pattern, blocksize)
+            block_data = get_sparse_data(blockpattern)
+
+            strided_values = nonblock_data[torch.strided][0]
+            block_indices = block_data[torch.sparse_coo][0]
+            # one dense tile per non-empty block, stacked in the order of
+            # the block-level COO indices
+            bsr_values = torch.stack([strided_values[bi * blocksize[0]:(bi + 1) * blocksize[0],
+                                                     bj * blocksize[1]:(bj + 1) * blocksize[1]]
+                                      for bi, bj in block_indices.transpose(0, 1)])
+
+            # here we use the property `values == range(1, 1+nnz)` and
+            # `values` relation to `csc_values` (see get_sparse_data)
+            # to get BSC blocks via reordering the BSR blocks:
+            # (block-level csc_values are 1-based positions into bsr_values,
+            # hence the `- 1`)
+            bsc_values = bsr_values[block_data[torch.sparse_csc][2] - 1]
+
+            return {torch.sparse_bsr: (*block_data[torch.sparse_csr][:2], bsr_values),
+                    torch.sparse_bsc: (*block_data[torch.sparse_csc][:2], bsc_values),
+                    **nonblock_data}
+
+        def get_batch_sparse_data(pattern, blocksize):
+            # Like get_sparse_data_with_block, but accepts patterns with
+            # leading batch dimensions.  A pattern with at most 2 dimensions
+            # is handled directly; otherwise the function recurses over the
+            # first dimension and merges the per-item results layout by
+            # layout.  Returns a dict keyed by layout whose values are
+            # tuples of tensors.
+            size = pattern.shape
+            if len(size) <= 2:  # non-batch
+                return get_sparse_data_with_block(pattern, blocksize)
+
+            # batch data is created recursively:
+            batch_data = {}  # type: ignore[var-annotated]
+            for i, item in enumerate(pattern):
+                for layout, d in get_batch_sparse_data(item, blocksize).items():
+                    target = batch_data.get(layout)
+                    if layout is torch.sparse_coo:
+                        # a "batch COO" means a COO with the leading
+                        # sparse dimensions interpreted as batch
+                        # dimensions
+                        # prepend a row holding the batch index i to the
+                        # item's COO indices
+                        ext_coo_indices1 = torch.cat((torch.full((1, len(d[1])), i, dtype=torch.int64), d[0]))
+                        if target is None:
+                            target = batch_data[layout] = (ext_coo_indices1, d[1])
+                        else:
+                            # The accumulated tensors live in a tuple, so
+                            # they are grown in place with set_():
+                            # indices along dim 1, values along dim 0.
+                            target[0].set_(torch.cat((target[0], ext_coo_indices1), 1))  # type: ignore[call-overload]
+                            target[1].set_(torch.cat((target[1], d[1])))
+                    else:
+                        # every other layout gains a new leading batch
+                        # dimension; items are concatenated along it
+                        if target is None:
+                            target = batch_data[layout] = tuple(d[j].unsqueeze(0) for j in range(len(d)))
+                        else:
+                            for j in range(len(d)):
+                                target[j].set_(torch.cat((target[j], d[j].unsqueeze(0))))  # type: ignore[call-overload]
+            return batch_data
+
+        def generate_values(base, densesize):
+            """Generates a tensor of shape densesize with values equal to
+
+              base + i_1 * 10^0 + ... + i_d * 10^{d - 1}
+
+            at indices i_1, ..., i_d (with 0 <= i_j < densesize[j] for any 1 <= j <=
+            len(densesize))
+
+            This mapping produces unique values as long as
+            densesize[i] < 10 for all i in range(len(densesize)).
+
+            Special cases, in the order they are checked:
+
+            - an empty densesize returns base unchanged;
+            - a non-int base with one or more dimensions is processed
+              element-wise along its first dimension and the results are
+              stacked, so the output shape is base.shape + densesize;
+            - a base equal to 0 yields an all-zero int64 tensor of shape
+              densesize.
+            """
+
+            if not densesize:
+                return base
+            if not isinstance(base, int) and base.ndim > 0:
+                return torch.stack([generate_values(b, densesize) for b in base])
+            if base == 0:
+                return torch.zeros(densesize, dtype=torch.int64)
+            # start with the offsets of the first dense dimension, then add
+            # one broadcast axis per remaining dimension, scaled by 10, 100, ...
+            r = torch.arange(densesize[0], dtype=torch.int64)
+            for i, d in enumerate(densesize[1:]):
+                y = torch.arange(d, dtype=torch.int64) * (10 ** (i + 1))
+                r = r[..., None] + y[None, ...]
+            r.add_(base)
+            return r
+
+        if patterns is None:
+            # A pattern is a 3-tuple with the following items:
+            #
+            # - a list of integers with the depth of two or more. The
+            #   integers define the sparsity patterns of the generated
+            #   inputs: zero values correspond to unspecified
+            #   elements/blocks, and non-zero values to the specified
+            #   elements.
+            #
+            #   For debugging convenience, the elements with the same
+            #   value typically belong to the same block. However, it
+            #   is not a hard requirement: as long as the shape of a
+            #   pattern divides with block sizes, the pattern will be
+            #   a valid one.
+            #
+            #   If the depth of the list is larger than two, inputs
+            #   with batch dimensions will be generated.
+            #
+            # - a list of 2-tuples of block sizes, used to generate
+            #   BSR/BSC tensors with various block size parameters
+            #
+            # - a list of tuples of dense dimensions, used to generate
+            #   hybrid tensors with various dense dimensions
+            #
+            patterns = [
+                # a simple 3 x 2 tensor: non-hybrid, hybrid with 1 and 2 dense dimensions
+                ([[1, 2, 0],
+                  [1, 0, 3]], [(2, 1), (1, 3)], [(), (2,), (4, 5)]),
+                # 2 x 3 batch of 3 x 2 tensors: non-hybrid and hybrid with 2 dense dimensions
+                ([[[[1, 2, 0],
+                    [1, 0, 3]],
+                   [[1, 2, 3],
+                    [1, 0, 0]],
+                   [[1, 0, 0],
+                    [1, 2, 3]]],
+                  [[[0, 2, 0],
+                    [1, 2, 3]],
+                   [[1, 0, 3],
+                    [1, 2, 0]],
+                   [[1, 2, 3],
+                    [0, 2, 0]]]], [(2, 1), (2, 3)], [(), (2,)]),
+                # tensor with non-trivial blocksize
+                ([[0, 1, 0, 2, 0, 2],
+                  [0, 1, 0, 0, 2, 0],
+                  [3, 3, 3, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 5, 0, 6, 6, 6],
+                  [5, 0, 5, 6, 6, 6],
+                  [0, 0, 0, 0, 8, 8],
+                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 5)]),
+                # batch tensor with variable NSE
+                # Requires https://github.com/pytorch/pytorch/pull/84843 or similar.
+                ([[[1, 2],
+                   [3, 4]],
+                  [[1, 0],
+                   [0, 0]]], [(1, 1)], ([()] if enable_batch_variable_nse else []))]
+
+        def non_contiguous_copy(t, dim=-1, offset=0):
+            # return a copy of t that is non-contiguous along the
+            # given dimension and with the given storage offset
+            #
+            # t must be contiguous.  dim may be negative (counted from the
+            # end).  The copy is a strided slice of a zero-filled buffer
+            # that is `step` times larger than t along `dim`.
+            self.assertTrue(t.is_contiguous())
+            if dim < 0:
+                dim = dim + t.ndim
+            assert dim >= 0 and dim < t.ndim
+            # step is at least 2 so the slice skips elements, and larger
+            # than offset so that the slice offset::step has exactly
+            # t.shape[dim] elements
+            step = max(2, offset + 1)
+            tmp = torch.zeros((*t.shape[:dim], t.shape[dim] * step, *t.shape[dim + 1:]), dtype=t.dtype, device=t.device)
+            dim_slices = (*((slice(None),) * dim), slice(offset, None, step))
+            r = tmp[dim_slices].copy_(t)
+            # sanity checks: the result is non-contiguous and equal to t
+            self.assertFalse(r.is_contiguous())
+            self.assertEqual(t, r)
+            return r
+
+        # the main loop of the method:
+        for pattern, blocksizes, densesizes in patterns:
+            if not enable_hybrid:
+                densesizes = [s for s in densesizes if not s]
+            if not (densesizes and blocksizes):
+                continue
+            pattern = torch.tensor(pattern, dtype=torch.int64)
+            if not enable_batch and pattern.ndim > 2:
+                continue
+            for blocksize in blocksizes:
+                data = get_batch_sparse_data(pattern, blocksize)[layout]
+                for densesize in densesizes:
+                    indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]]
+                    values = generate_values(data[-1], densesize).to(device=device, dtype=dtype)
+                    kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize)
+                    if pin_memory is not None:
+                        kwargs.update(pin_memory=pin_memory)
+
+                    yield (*indices, values), kwargs.copy()
+                    if enable_non_contiguous_indices and pattern.ndim > 2:
+                        # sparse compressed indices can be sliced only along batch dimensions
+                        for (dim, offset) in {(0, 1), (-2, 0)}:
+                            indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices]
+                            yield (*indices_copy, values), kwargs.copy()
+
+                            if enable_non_contiguous_values:
+                                values_copy = non_contiguous_copy(values, dim=-1, offset=1)
+                                yield (*indices_copy, values_copy), kwargs.copy()
+
+                    if enable_non_contiguous_values:
+                        values_copy = non_contiguous_copy(values, dim=-1, offset=1)
+                        yield (*indices, values_copy), kwargs.copy()
+
+        # zero-sized tensor inputs, non-batch, non-hybrid/hybrid
+        if enable_zero_sized:
+            for basesize, blocksizes, densesizes in [
+                    ((2, 0), [(1, 2)], [(), (2,), (2, 3)] if enable_hybrid else [()]),
+                    ((0, 2), [(1, 2), (2, 1), (3, 2)], [()]),
+                    ((0, 0), [(1, 2)], [()]),
+            ]:
+                for blocksize in blocksizes:
+                    for densesize in densesizes:  # type: ignore[attr-defined]
+                        if layout == torch.strided:
+                            indices = ()  # type: ignore[assignment]
+                            values = torch.empty((basesize + densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_coo:
+                            indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_csr:
+                            crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype)
+                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (crow_indices, col_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_csc:
+                            ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype)
+                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (ccol_indices, row_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_bsr:
+                            crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype)
+                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (crow_indices, col_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_bsc:
+                            ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype)
+                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (ccol_indices, row_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
+                        else:
+                            assert 0  # unreachable
+                        kwargs = dict(device=device, dtype=dtype, size=basesize + densesize)
+                        if pin_memory is not None:
+                            kwargs.update(pin_memory=pin_memory)
+                        yield (*indices, values), kwargs
+
+    def safeToDense(self, t):
+        # coalesce is only implemented for COO
+        if t.layout == torch.sparse_coo:
+            t = t.coalesce()
+        return t.to_dense()
+
+    # Compares a torch function with a reference function for a given sample input (object of SampleInput)
+    # Note: only values are compared, type comparison is not done here
+    def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs):
+        numpy_sample = sample_input.numpy()
+        n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs
+        t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+
+        actual = torch_fn(t_inp, *t_args, **t_kwargs)
+        expected = ref_fn(n_inp, *n_args, **n_kwargs)
+
+        self.assertEqual(actual, expected, exact_device=False, **kwargs)
+
+    # Compares the given Torch and NumPy functions on the given tensor-like object.
+    # NOTE: both torch_fn and np_fn should be functions that take a single
+    #   tensor (array). If the torch and/or NumPy function require additional
+    #   arguments then wrap the function in a lambda or pass a partial function.
+    # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol)
+    def compare_with_numpy(self, torch_fn, np_fn, tensor_like,
+                           device=None, dtype=None, **kwargs):
+        """Apply ``torch_fn`` and ``np_fn`` to the same data and assert equal results.
+
+        Args:
+            torch_fn: callable taking a single tensor.
+            np_fn: callable taking a single NumPy array.
+            tensor_like: either a ``torch.Tensor`` (then ``device`` and
+                ``dtype`` must be None) or data from which both a NumPy
+                array and a tensor are built using ``dtype``/``device``.
+            device, dtype: only used when ``tensor_like`` is not a tensor.
+            **kwargs: forwarded to ``self.assertEqual``.
+
+        bfloat16 data is handed to NumPy as float32.
+        """
+        assert TEST_NUMPY
+
+        if isinstance(tensor_like, torch.Tensor):
+            assert device is None
+            assert dtype is None
+            t_cpu = tensor_like.detach().cpu()
+            # bfloat16 is upcast to float before the NumPy conversion
+            # (presumably because NumPy has no bfloat16 dtype)
+            if t_cpu.dtype is torch.bfloat16:
+                t_cpu = t_cpu.float()
+            a = t_cpu.numpy()
+            t = tensor_like
+        else:
+            # work on a copy of the dtype map so the bfloat16 override
+            # stays local to this call
+            d = copy.copy(torch_to_numpy_dtype_dict)
+            d[torch.bfloat16] = np.float32
+            a = np.array(tensor_like, dtype=d[dtype])
+            t = torch.tensor(tensor_like, device=device, dtype=dtype)
+
+        np_result = np_fn(a)
+        torch_result = torch_fn(t).cpu()
+
+        # Converts arrays to tensors
+        if isinstance(np_result, np.ndarray):
+            try:
+                np_result = torch.from_numpy(np_result)
+            except Exception:
+                # NOTE: copying an array before conversion is necessary when,
+                #   for example, the array has negative strides.
+                np_result = torch.from_numpy(np_result.copy())
+            # The NumPy side was computed in float32 for bfloat16 inputs, so
+            # upcast a bfloat16 torch result to float to compare like with like.
+            if t.dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float:
+                torch_result = torch_result.to(torch.float)
+
+        self.assertEqual(np_result, torch_result, **kwargs)
+
+    def assertEqualIgnoreType(self, *args, **kwargs) -> None:
+        """Call ``self.assertEqual`` with ``exact_dtype=False``.
+
+        All positional and keyword arguments are forwarded unchanged, so
+        passing ``exact_dtype`` here as well raises a ``TypeError``.
+        """
+        # If you are seeing this function used, that means test is written wrongly
+        # and deserves detailed investigation
+        return self.assertEqual(*args, exact_dtype=False, **kwargs)
+
+    def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None:
+        r"""Asserts that tensor x equals y after y is expanded to x's shape.
+
+        A non-iterable y (e.g. a Python number) and an iterable y that is not
+        a tensor are both turned into a tensor shaped like x by multiplying
+        with ``torch.ones_like(x)``.  A tensor y is passed to
+        ``assertEqual`` unchanged.  Extra positional and keyword arguments
+        are forwarded to ``assertEqual``.
+        """
+        if not isinstance(y, Iterable):
+            # int, float, etc.
+            # NOTE(review): torch.Tensor appears to satisfy Iterable, so
+            # tensors of a different shape would not reach this branch -- confirm
+            y = torch.ones_like(x) * y
+        if not isinstance(y, torch.Tensor):
+            # iterable, but not a tensor
+            y = torch.ones_like(x) * torch.tensor(y)
+        return self.assertEqual(x, y, *args, **kwargs)
+
+    def assertEqual(
+            self,
+            x,
+            y,
+            msg: Optional[Union[str, Callable[[str], str]]] = None,
+            *,
+            atol: Optional[float] = None,
+            rtol: Optional[float] = None,
+            equal_nan=True,
+            exact_dtype=True,
+            # TODO: default this to True
+            exact_device=False,
+            exact_layout=False,
+            exact_stride=False,
+            exact_is_coalesced=False
+    ):
+        """Assert that ``x`` and ``y`` are equal according to ``not_close_error_metas``.
+
+        Args:
+            x, y: the objects to compare.  Before the comparison, inputs are
+                normalized: if either is a NumPy array whose dtype has no
+                torch counterpart, both are converted to lists; a sequence
+                compared with a tensor is converted to a tensor with that
+                tensor's dtype and device; strided nested tensors are unbound.
+            msg: optional custom message, or a callable mapping the generated
+                message to the final one.  A string message is appended to
+                the generated message on a new line when ``self.longMessage``
+                is true; otherwise ``msg`` is passed to ``to_error`` as is.
+            atol, rtol: tolerances, forwarded together with
+                ``self.precision`` / ``self.rel_tol`` as ``atol_override`` /
+                ``rtol_override``.
+            equal_nan: forwarded as ``equal_nan``.
+            exact_dtype, exact_device, exact_layout, exact_stride,
+            exact_is_coalesced: forwarded as ``check_dtype``,
+                ``check_device``, ``check_layout``, ``check_stride`` and
+                ``check_is_coalesced`` respectively.
+
+        Raises:
+            The error built by ``to_error`` of the first reported error meta
+            when the comparison reports any mismatch.
+        """
+        # Hide this function from `pytest`'s traceback
+        __tracebackhide__ = True
+
+        # numpy's dtypes are a superset of what PyTorch supports. In case we encounter an unsupported dtype, we fall
+        # back to an elementwise comparison. Note that this has to happen here and not for example in
+        # `TensorOrArrayPair`, since at that stage we can no longer split the array into its elements and perform
+        # multiple comparisons.
+        if any(
+            isinstance(input, np.ndarray) and not has_corresponding_torch_dtype(input.dtype) for input in (x, y)
+        ):
+            def to_list(input):
+                return input.tolist() if isinstance(input, (torch.Tensor, np.ndarray)) else list(input)
+
+            x = to_list(x)
+            y = to_list(y)
+        # When comparing a sequence of numbers to a tensor, we need to convert the sequence to a tensor here.
+        # Otherwise, the pair origination of `are_equal` will fail, because the sequence is recognized as container
+        # that should be checked elementwise while the tensor is not.
+        elif isinstance(x, torch.Tensor) and isinstance(y, Sequence):
+            y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
+        elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
+            x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
+
+        # unbind NSTs to compare them; don't do this for NJTs
+        if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
+            x = x.unbind()
+        if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
+            y = y.unbind()
+
+        error_metas = not_close_error_metas(
+            x,
+            y,
+            pair_types=(
+                NonePair,
+                RelaxedBooleanPair,
+                RelaxedNumberPair,
+                TensorOrArrayPair,
+                TypedStoragePair,
+                StringPair,
+                SetPair,
+                TypePair,
+                ObjectPair,
+            ),
+            sequence_types=(
+                Sequence,
+                Sequential,
+                ModuleList,
+                ParameterList,
+                ScriptList,
+                torch.utils.data.dataset.Subset,
+            ),
+            mapping_types=(Mapping, ModuleDict, ParameterDict, ScriptDict),
+            rtol=rtol,
+            rtol_override=self.rel_tol,
+            atol=atol,
+            atol_override=self.precision,
+            equal_nan=equal_nan,
+            check_device=exact_device,
+            check_dtype=exact_dtype,
+            check_layout=exact_layout,
+            check_stride=exact_stride,
+            check_is_coalesced=exact_is_coalesced,
+        )
+
+        if error_metas:
+            # See [ErrorMeta Cycles]
+            # The metas are wrapped in a one-element list and popped from it
+            # inside the raise expression below, so only the first meta is
+            # turned into an error.
+            error_metas = [error_metas]  # type: ignore[list-item]
+            # TODO: compose all metas into one AssertionError
+            raise error_metas.pop()[0].to_error(  # type: ignore[index]
+                # This emulates unittest.TestCase's behavior if a custom message passed and
+                # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
+                # is True (default)
+                (lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg
+            )
+
+    def assertNotEqual(self, x, y, msg: Optional[str] = None, *,                                       # type: ignore[override]
+                       atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None:
+        """Assert that ``self.assertEqual(x, y, ...)`` fails.
+
+        The check passes exactly when ``assertEqual`` raises an
+        ``AssertionError`` for the given arguments; ``atol``, ``rtol`` and
+        any extra ``kwargs`` are forwarded to ``assertEqual``.
+        """
+        with self.assertRaises(AssertionError, msg=msg):
+            self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs)
+
+    def assertEqualTypeString(self, x, y) -> None:
+        """Assert that ``x`` and ``y`` agree in device, dtype and sparsity."""
+        # This API is used to simulate the deprecated check `x.type() is y.type()`
+        self.assertEqual(x.device, y.device)
+        self.assertEqual(x.dtype, y.dtype)
+        self.assertEqual(x.is_sparse, y.is_sparse)
+
+    def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
+        for elem in iterable:
+            if id(obj) == id(elem):
+                return
+        raise AssertionError("object not found in iterable")
+
+    # Reimplemented to provide special behavior when
+    # _ignore_not_implemented_error is True
+    def assertRaises(self, expected_exception, *args, **kwargs):
+        """``assertRaises`` with an alternate context when NotImplementedError is ignored.
+
+        When ``self._ignore_not_implemented_error`` is true the check is
+        handled by ``AssertRaisesContextIgnoreNotImplementedError``;
+        otherwise it is delegated to the parent class implementation.
+        """
+        if self._ignore_not_implemented_error:
+            context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
+                AssertRaisesContextIgnoreNotImplementedError(expected_exception, self)  # type: ignore[call-arg]
+            try:
+                return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr, arg-type]
+            finally:
+                # see https://bugs.python.org/issue23890
+                # drop the local reference to the context once handling is done
+                context = None
+        else:
+            return super().assertRaises(expected_exception, *args, **kwargs)
+
+    # Reimplemented to provide special behavior when
+    # _ignore_not_implemented_error is True
+    def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs):
+        """``assertRaisesRegex`` that skips message matching on non-native devices.
+
+        The regex is replaced by the empty string (which matches anything)
+        when the test has a ``device_type`` that is neither in
+        ``NATIVE_DEVICES`` nor equal to ``"mps"``.  When
+        ``self._ignore_not_implemented_error`` is true the check is handled
+        by ``AssertRaisesContextIgnoreNotImplementedError``; otherwise it is
+        delegated to the parent class implementation.
+        """
+        # Verifies that an exception with the type expected_exception and message
+        # matching the regular expression defined by expected_regex is thrown.
+        # If the test is instantiated for a non-native device type (like XLA)
+        # then the message is not validated.
+
+        # Checks whether the test is instantiated for a device type by testing
+        # if the test class has defined the device_type attribute and,
+        # if so, tests whether the instantiated device type is native or not
+        if hasattr(self, 'device_type') and self.device_type not in NATIVE_DEVICES and self.device_type != "mps":  # type: ignore[attr-defined]
+            # empty string matches any string
+            expected_regex = ''
+
+        if self._ignore_not_implemented_error:
+            context = AssertRaisesContextIgnoreNotImplementedError(  # type: ignore[call-arg]
+                expected_exception, self, expected_regex)
+            return context.handle('assertRaisesRegex', args, kwargs)  # type: ignore[attr-defined, arg-type]
+        else:
+            return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
+
+    # Verifies that no unraisable exceptions are raised by callable.  Unlike regular
+    # exceptions, these do not actually propagate to the caller and are
+    # suppressed.  We must test for them specially.
+    def assertNoUnraisable(self, callable, *args, **kwargs):
+        """Run ``callable(*args, **kwargs)`` and assert ``sys.unraisablehook`` was not called.
+
+        While the callable runs, ``sys.unraisablehook`` is patched with a
+        recorder and garbage collection is disabled; GC is re-enabled
+        afterwards only if it was enabled on entry.  If the hook fires
+        several times, only the last report is kept.
+        """
+        raised = None
+
+        def record_unraisable(unraisable):
+            # remember the most recent report from the unraisable hook
+            nonlocal raised
+            raised = unraisable
+
+        # Disable GC when running the callable to prevent spurious flakiness
+        # from unlucky GCs inside the callable
+        prev = gc.isenabled()
+        gc.disable()
+        try:
+            with unittest.mock.patch("sys.unraisablehook", record_unraisable):
+                callable(*args, **kwargs)
+        finally:
+            if prev:
+                gc.enable()
+
+        self.assertIsNone(raised)
+
+    # TODO: Support context manager interface
+    # NB: The kwargs forwarding to callable robs the 'subname' parameter.
+    # If you need it, manually apply your callable in a lambda instead.
+    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
+        subname = None
+        if 'subname' in kwargs:
+            subname = kwargs['subname']
+            del kwargs['subname']
+        try:
+            callable(*args, **kwargs)
+        except exc_type as e:
+            self.assertExpected(str(e), subname)
+            return
+        # Don't put this in the try block; the AssertionError will catch it
+        self.fail(msg="Did not raise when expected to")
+
+    def assertNotWarn(self, callable, msg=''):
+        r"""
+        Test if :attr:`callable` does not raise a warning.
+
+        :attr:`callable` is invoked without arguments while every warning is
+        recorded; the test fails with :attr:`msg` if any warning was
+        recorded.
+        """
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            with set_warn_always_context(True):
+                callable()
+            # checked while the recording context is still active
+            self.assertTrue(len(ws) == 0, msg)
+
+    @contextmanager
+    def assertWarnsOnceRegex(self, category, regex=''):
+        """Context manager for code that *must always* warn
+
+        This filters expected warnings from the test and fails if
+        the expected warning is not caught. It uses set_warn_always() to force
+        TORCH_WARN_ONCE to behave like TORCH_WARN
+
+        Args:
+            category: warning class; a recorded warning counts only when the
+                type of its message is exactly this class (subclasses do not
+                match).
+            regex: pattern matched with ``re.match`` (i.e. anchored at the
+                start) against the string form of each warning message.
+
+        NOTE: the category check and the regex check are evaluated
+        independently over all recorded warnings, so they may be satisfied
+        by two different warnings.
+        """
+        pattern = re.compile(regex)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            with set_warn_always_context(True):
+                yield
+            if len(ws) == 0:
+                self.fail('no warning caught')
+            self.assertTrue(any(type(w.message) is category for w in ws))
+            # on failure, report the pattern and the messages of the
+            # warnings that did match the category
+            self.assertTrue(
+                any(re.match(pattern, str(w.message)) for w in ws),
+                f'{pattern}, {[w.message for w in ws if type(w.message) is category]}')
+
+    def assertExpected(self, s, subname=None):
+        r"""
+        Test that a string matches the recorded contents of a file
+        derived from the name of this test and subname.  This file
+        is placed in the 'expect' directory in the same directory
+        as the test script. You can automatically update the recorded test
+        output using --accept.
+
+        If you call this multiple times in a single function, you must
+        give a unique subname each time.
+
+        Args:
+            s: the string to check; anything other than ``str`` raises
+                ``TypeError``.
+            subname: optional suffix appended (after a ``-``) to the expect
+                file name.
+
+        Raises:
+            TypeError: if ``s`` is not a string.
+            RuntimeError: if the expect file does not exist and accept mode
+                (``expecttest.ACCEPT``) is off.
+            OSError: if the expect file cannot be read for a reason other
+                than not existing.
+        """
+        if not isinstance(s, str):
+            raise TypeError("assertExpected is strings only")
+
+        def remove_prefix(text, prefix):
+            # strip `prefix` from the start of `text` when present
+            if text.startswith(prefix):
+                return text[len(prefix):]
+            return text
+        # NB: we take __file__ from the module that defined the test
+        # class, so we place the expect directory where the test script
+        # lives, NOT where test/common_utils.py lives.  This doesn't matter in
+        # PyTorch where all test scripts are in the same directory as
+        # test/common_utils.py, but it matters in onnx-pytorch
+        module_id = self.__class__.__module__
+        munged_id = remove_prefix(self.id(), module_id + ".")
+        test_file = os.path.realpath(sys.modules[module_id].__file__)  # type: ignore[type-var]
+        expected_file = os.path.join(os.path.dirname(test_file),  # type: ignore[type-var, arg-type]
+                                     "expect",
+                                     munged_id)
+
+        subname_output = ""
+        if subname:
+            expected_file += "-" + subname
+            subname_output = f" ({subname})"
+        expected_file += ".expect"
+        expected = None
+
+        def accept_output(update_type):
+            # write `s` to the expect file, with the producer version
+            # replaced by a placeholder
+            print(f"Accepting {update_type} for {munged_id}{subname_output}:\n\n{s}")
+            with open(expected_file, 'w') as f:
+                # Adjust for producer_version, leave s unmodified
+                s_tag = re.sub(r'(producer_version): "[0-9.]*"',
+                               r'\1: "CURRENT_VERSION"', s)
+                f.write(s_tag)
+
+        try:
+            with open(expected_file) as f:
+                expected = f.read()
+        except OSError as e:
+            # only a missing file is handled here: it is created in accept
+            # mode and reported with instructions otherwise
+            if e.errno != errno.ENOENT:
+                raise
+            elif expecttest.ACCEPT:
+                return accept_output("output")
+            else:
+                raise RuntimeError(
+                      f"I got this output for {munged_id}{subname_output}:\n\n{s}\n\n"
+                      "No expect file exists; to accept the current output, run:\n"
+                      f"python {__main__.__file__} {munged_id} --accept") from None
+
+        # a hack for JIT tests
+        if IS_WINDOWS:
+            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
+            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)
+
+        # Adjust for producer_version
+        expected = expected.replace(
+            'producer_version: "CURRENT_VERSION"',
+            f'producer_version: "{torch.onnx.producer_version}"'
+        )
+        if expecttest.ACCEPT:
+            # in accept mode a mismatch rewrites the expect file instead of failing
+            if expected != s:
+                return accept_output("updated output")
+        else:
+            if hasattr(self, "assertMultiLineEqual"):
+                # Python 2.7 only
+                # NB: Python considers lhs "old" and rhs "new".
+                self.assertMultiLineEqual(expected, s)
+            else:
+                self.assertEqual(s, expected)
+
+    def assertExpectedStripMangled(self, s, subname=None):
+        """Run ``assertExpected`` on ``s`` with ``__torch__...`` names removed.
+
+        Every ``__torch__`` occurrence followed by one or more non-space
+        characters is deleted from ``s`` before the comparison.
+        """
+        s = re.sub(r'__torch__[^ ]+', '', s)
+        self.assertExpected(s, subname)
+
+    def assertGreaterAlmostEqual(self, first, second, places=None, msg=None, delta=None):
+        """Assert that ``first`` is greater than or almost equal to ``second``.
+
+        The equality of ``first`` and ``second`` is determined in a similar way to
+        the ``assertAlmostEqual`` function of the standard library.
+        """
+        if delta is not None and places is not None:
+            raise TypeError("specify delta or places not both")
+
+        if first >= second:
+            return
+
+        diff = second - first
+        if delta is not None:
+            if diff <= delta:
+                return
+
+            standardMsg = f"{first} not greater than or equal to {second} within {delta} delta"
+        else:
+            if places is None:
+                places = 7
+
+            if round(diff, places) == 0:
+                return
+
+            standardMsg = f"{first} not greater than or equal to {second} within {places} places"
+
+        msg = self._formatMessage(msg, standardMsg)
+        raise self.failureException(msg)
+
+    def assertAtenOp(self, onnx_model, operator, overload_name=""):
+        """Assert that ``onnx_model`` contains an ATen node for ``operator``.
+
+        Only graph nodes with ``op_type == "ATen"`` and domain
+        ``"org.pytorch.aten"`` are considered, and at least one must exist.
+        The first such node whose ``operator`` attribute equals ``operator``
+        is also checked for ``overload_name`` (a missing attribute counts as
+        ``""``).
+        """
+        all_aten_nodes = [p for p in onnx_model.graph.node
+                          if p.op_type == "ATen" and p.domain == "org.pytorch.aten"]
+        self.assertTrue(all_aten_nodes)
+
+        # Stop at the first node with a matching operator; if none matches,
+        # `attrs` is left holding the attributes of the last ATen node and
+        # the assertion below fails.
+        for op in all_aten_nodes:
+            attrs = {attr.name: attr.s.decode() for attr in op.attribute}
+            if attrs.get("operator") == operator:
+                break
+
+        self.assertEqual(attrs["operator"], operator)  # type: ignore[possibly-undefined]
+        self.assertEqual(attrs.get("overload_name", ""), overload_name)
+
+    def check_nondeterministic_alert(self, fn, caller_name, should_alert=True):
+        '''Checks that an operation produces a nondeterministic alert when
+        expected while `torch.use_deterministic_algorithms(True)` is set.
+
+        Args:
+          fn (callable): Function to check for a nondeterministic alert
+
+          caller_name (str): Name of the operation that produces the
+              nondeterministic alert. This name is expected to appear at the
+              beginning of the error/warning message.
+
+          should_alert (bool, optional): If True, then the check will only pass
+              if calling `fn` produces a nondeterministic error/warning with the
+              expected message. If False, then the check will only pass if
+              calling `fn` does not produce an error. Default: `True`.
+        '''
+
+        alert_message = '^' + caller_name + ' does not have a deterministic implementation, but you set'
+
+        # Check that errors are thrown correctly
+        with DeterministicGuard(True):
+            if should_alert:
+                with self.assertRaisesRegex(
+                        RuntimeError,
+                        alert_message,
+                        msg='expected a non-deterministic error, but it was not raised'):
+                    fn()
+
+            else:
+                # If a nondeterministic error is not expected, make sure
+                # that it is not raised
+                try:
+                    fn()
+                except RuntimeError as e:
+                    if 'does not have a deterministic implementation' in str(e):
+                        self.fail(
+                            'did not expect non-deterministic error message, '
+                            + 'but got one anyway: "' + str(e) + '"')
+                    # Reraise exceptions unrelated to nondeterminism
+                    raise
+
+        # Check that warnings are thrown correctly
+        with DeterministicGuard(True, warn_only=True):
+            if should_alert:
+                with self.assertWarnsRegex(
+                        UserWarning,
+                        alert_message):
+                    fn()
+            else:
+                with warnings.catch_warnings(record=True) as w:
+                    warnings.simplefilter("always")
+                    fn()
+                    for warning in w:
+                        if isinstance(warning, UserWarning):
+                            self.assertTrue(re.search(alert_message, str(warning)) is None)
+
+    # run code in subprocess and capture exceptions.
+    @staticmethod
+    def run_process_no_exception(code, env=None):
+        import subprocess
+
+        with subprocess.Popen(
+            [sys.executable, "-c", code],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+        ) as p:
+            (stdout, stderr) = p.communicate()
+            return (stdout, stderr)
+
+    # returns captured stderr
+    @staticmethod
+    def runWithPytorchAPIUsageStderr(code):
+        env = os.environ.copy()
+        env["PYTORCH_API_USAGE_STDERR"] = "1"
+        # remove CI flag since this is a wrapped test process.
+        # CI flag should be set in the parent process only.
+        env.pop("CI", None)
+        env.pop("TEST_SHOWLOCALS", None)
+        _stdout, stderr = TestCase.run_process_no_exception(code, env=env)
+        return stderr.decode('ascii')
+
+    def _attempt_load_from_subprocess(
+        self,
+        file: pathlib.Path,
+        import_string: str,
+        expected_failure_message: Optional[str] = None
+    ) -> None:
+        """
+        Attempts weights_only `torch.load` in a subprocess. This is used to test that
+        weights_only `torch.load` works as expected without global imports.
+
+        Args:
+            file (pathlib.Path): The path to the checkpoint to load.
+            import_string (str): import string to add to the script
+            exected_failure_message (str, optional): The expected failure message if the
+                checkpoint fails to load. If None, the test will pass
+        """
+        script = f"import torch;{import_string}torch.load(r'{file}', weights_only=True)"
+        cm = (
+            self.assertRaisesRegex(RuntimeError, re.escape(expected_failure_message))
+            if expected_failure_message else contextlib.nullcontext()
+        )
+        with cm:
+            try:
+                subprocess.check_output(
+                    [sys.executable, "-c", script],
+                    # On Windows, opening the subprocess with the default CWD makes `import torch`
+                    # fail, so just set CWD to this script's directory
+                    cwd=os.path.dirname(os.path.realpath(__file__)),
+                    stderr=subprocess.STDOUT,
+                )
+            except subprocess.CalledProcessError as e:
+                raise RuntimeError(e.output.decode("utf-8")) from None
+
+
+class TestCaseBase(TestCase):
+    """Empty subclass of ``TestCase`` that adds no behavior of its own."""
+    # Calls to super() in dynamically created classes are a bit odd.
+    # See https://github.com/pytorch/pytorch/pull/118586 for more info
+    # Subclassing this class and then calling super(TestCaseBase) will run
+    # TestCase's setUp, tearDown etc functions
+    pass
+
+
+def download_file(url, binary=True):
+    from urllib.parse import urlsplit
+    from urllib import request, error
+
+    filename = os.path.basename(urlsplit(url)[2])
+    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
+    path = os.path.join(data_dir, filename)
+
+    if os.path.exists(path):
+        return path
+    try:
+        with request.urlopen(url, timeout=15) as f1, open(path, 'wb' if binary else 'w') as f2:
+            data = f1.read()
+            f2.write(data)
+        return path
+    except error.URLError as e:
+        msg = f"could not download test file '{url}'"
+        warnings.warn(msg, RuntimeWarning, stacklevel=2)
+        raise unittest.SkipTest(msg) from e
+
+def find_free_port():
+    """
+    Finds an available port and returns that port number.
+
+    NOTE: If this function is being used to allocate a port to Store (or
+    indirectly via init_process_group or init_rpc), it should be used
+    in conjunction with the `retry_on_connect_failures` decorator as there is a potential
+    race condition where the allocated port may become unavailable before it can be used
+    """
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.bind(('localhost', 0))
+        _, port = sock.getsockname()
+        return port
+
+# Errors that we can get in c10d initialization for which we should retry tests for.
+ADDRESS_IN_USE = "Address already in use"
+CONNECT_TIMEOUT = "connect() timed out."
+
+def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE)):
+    """Reruns a test if the test returns a RuntimeError and the exception
+    contains one of the strings in connect_errors."""
+    # This if block is executed when using this function as a decorator with arguments.
+    if func is None:
+        return partial(retry_on_connect_failures, connect_errors=connect_errors)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        n_retries = 10
+        tries_remaining = n_retries
+        while True:
+            try:
+                return func(*args, **kwargs)
+            except RuntimeError as error:
+                if any(connect_error in str(error) for connect_error in connect_errors):
+                    tries_remaining -= 1
+                    if tries_remaining == 0:
+                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
+                    time.sleep(random.random())
+                    continue
+                raise
+    return wrapper
+
+
+# Decorator to retry upon certain Exceptions.
+def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
+    def deco_retry(f):
+        @wraps(f)
+        def f_retry(*args, **kwargs):
+            mtries, mdelay = tries, delay
+            while mtries > 1:
+                try:
+                    return f(*args, **kwargs)
+                except ExceptionToCheck as e:
+                    msg = f"{e}, Retrying in {mdelay:d} seconds..."
+                    print(msg)
+                    time.sleep(mdelay)
+                    mtries -= 1
+            try:
+                return f(*args, **kwargs)
+            except ExceptionToCheck as e:
+                raise unittest.SkipTest(f"Skipping after {tries} consecutive {str(e)}") from e if skip_after_retries else e
+        return f_retry  # true decorator
+    return deco_retry
+
+
+# FIXME: modernize these to be consistent with make_tensor
+#   and review including them in torch.testing
+# Methods for matrix generation
+
+def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
+    assert rank <= l
+    A = torch.randn(l, l, dtype=dtype, device=device)
+    u, s, vh = torch.linalg.svd(A, full_matrices=False)
+    for i in range(l):
+        if i >= rank:
+            s[i] = 0
+        elif s[i] == 0:
+            s[i] = 1
+    return (u * s.to(dtype).unsqueeze(-2)) @ vh
+
+def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
+    """
+    Returns a random rectangular matrix (batch of matrices)
+    with singular values sampled from a Gaussian with
+    mean `mean` and standard deviation `sigma`.
+    The smaller the `sigma`, the better conditioned
+    the output matrix is.
+    """
+    primitive_dtype = {
+        torch.float: torch.float,
+        torch.double: torch.double,
+        torch.cfloat: torch.float,
+        torch.cdouble: torch.double
+    }
+    x = torch.rand(shape, dtype=dtype, device=device)
+    m = x.size(-2)
+    n = x.size(-1)
+    u, _, vh = torch.linalg.svd(x, full_matrices=False)
+    s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
+        .sort(-1, descending=True).values.to(dtype)
+    return (u * s.unsqueeze(-2)) @ vh
+
+# Returns a noncontiguous (tensor with the same shape and values as t
+# The noncontiguous tensor is constructed such that elements in the innermost
+#   dimension are separated by zeros or (whenever possible) nans
+# TODO: consider more complicated noncontiguity schemes
+def noncontiguous_like(t):
+    # Short-circuits if t is already noncontiguous
+    if not t.is_contiguous():
+        return t
+
+    # Choose a "weird" value that won't be accessed
+    if t.dtype.is_floating_point or t.dtype.is_complex:
+        value = math.nan
+    elif t.dtype == torch.bool:
+        value = True
+    else:
+        value = 12
+
+    result = t.new_empty(t.shape + (2,))
+    result[..., 0] = value
+    result[..., 1] = t.detach()
+    result = result[..., 1]
+    result.requires_grad_(t.requires_grad)
+    return result
+
+# TODO: remove this (prefer make_symmetric_matrices below)
+def random_symmetric_matrix(l, *batches, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    A = (A + A.mT).div_(2)
+    return A
+
+# Creates a symmetric matrix or batch of symmetric matrices
+# Shape must be a square matrix or batch of square matrices
+def make_symmetric_matrices(*shape, device, dtype):
+    assert shape[-1] == shape[-2]
+    t = make_tensor(shape, device=device, dtype=dtype)
+    t = (t + t.mT).div_(2)
+    return t
+
+def random_hermitian_matrix(l, *batches, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    A = (A + A.mH).div_(2)
+    return A
+
+
+def random_symmetric_psd_matrix(l, *batches, **kwargs):
+    """
+    Returns a batch of random symmetric positive-semi-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_symmetric_psd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    return A @ A.mT
+
+
+def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'):
+    """
+    Returns a batch of random Hermitian positive-semi-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device)
+    return A @ A.mH
+
+
+# TODO: remove this (prefer make_symmetric_pd_matrices below)
+def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
+                    dtype=dtype, device=device)
+    return torch.matmul(A, A.mT) \
+        + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5
+
+
+# Creates a symmetric positive-definite matrix or batch of
+#   such matrices
+def make_symmetric_pd_matrices(*shape, device, dtype):
+    assert shape[-1] == shape[-2]
+    t = make_tensor(shape, device=device, dtype=dtype)
+    i = torch.eye(shape[-1], device=device, dtype=dtype) * 1e-5
+    return t @ t.mT + i
+
+def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device):
+    """
+    Returns a batch of random Hermitian positive-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
+                    dtype=dtype, device=device)
+    return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device)
+
+# Creates a full rank matrix with distinct singular values or
+#   a batch of such matrices
+def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
+    with torch.no_grad():
+        t = make_tensor(shape, device=device, dtype=dtype)
+        u, _, vh = torch.linalg.svd(t, full_matrices=False)
+        real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
+        k = min(shape[-1], shape[-2])
+        # We choose the singular values to be "around one"
+        # This is to make the matrix well conditioned
+        # s = [2, 3, ..., k+1]
+        s = torch.arange(2, k + 2, dtype=real_dtype, device=device)
+        # s = [2, -3, 4, ..., (-1)^k k+1]
+        s[1::2] *= -1.
+        # 1 + 1/s so that the singular values are in the range [2/3, 3/2]
+        # This gives a condition number of 9/4, which should be good enough
+        s.reciprocal_().add_(1.)
+        # Note that the singular values need not be ordered in an SVD so
+        # we don't need need to sort S
+        x = (u * s.to(u.dtype)) @ vh
+    x.requires_grad_(requires_grad)
+    return x
+
+def random_matrix(rows, columns, *batch_dims, **kwargs):
+    """Return rectangular matrix or batches of rectangular matrices.
+
+    Parameters:
+      dtype - the data type
+      device - the device kind
+      singular - when True, the output will be singular
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    silent = kwargs.get("silent", False)
+    singular = kwargs.get("singular", False)
+    if silent and not torch._C.has_lapack:
+        return torch.ones(rows, columns, dtype=dtype, device=device)
+
+    A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
+    if A.numel() == 0:
+        return A
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
+    k = min(rows, columns)
+    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
+    if singular:
+        # make matrix singular
+        s[k - 1] = 0
+        if k > 2:
+            # increase the order of singularity so that the pivoting
+            # in LU factorization will be non-trivial
+            s[0] = 0
+    return (u * s.unsqueeze(-2)) @ vh
+
+
+def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
+    """Return rectangular matrix or batches of rectangular matrices with
+    given rank.
+    """
+    B = random_matrix(rows, rank, *batch_dims, **kwargs)
+    C = random_matrix(rank, columns, *batch_dims, **kwargs)
+    return B.matmul(C)
+
+
+def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
+    """Generate indices for a row x cols matrix, preferring at least one index per row if possible."""
+    indices = []  # type: ignore[var-annotated]
+    n_per_row = math.ceil(num_indices / rows)
+    col_indices = list(range(cols))
+
+    for r in range(rows):
+        # Note that this can yield overlapping indices
+        indices.extend((r, c) for c in random.choices(col_indices, k=n_per_row))
+
+    return torch.tensor(indices[:num_indices])
+
+
+def random_sparse_matrix(rows, columns, density=0.01, **kwargs):
+    """Return rectangular random sparse matrix within given density.
+
+    The density of the result approaches to given density as the size
+    of the matrix is increased and a relatively small value of density
+    is specified but higher than min(rows, columns)/(rows * columns)
+    for non-singular matrices.
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+
+    nonzero_elements = max(min(rows, columns), int(rows * columns * density))
+    indices = _generate_indices_prefer_all_rows(rows, columns, nonzero_elements)
+    values = torch.randn(nonzero_elements, dtype=dtype, device=device)
+
+    # ensure that the diagonal dominates
+    values *= torch.tensor([-float(i - j)**2 for i, j in indices], dtype=dtype, device=device).exp()
+    A = torch.sparse_coo_tensor(indices.t(), values, (rows, columns), device=device)
+    return A.coalesce()
+
+
+def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
+    """Return random sparse positive-definite matrix with given density.
+
+    The eigenvalues of the matrix are defined as::
+      arange(1, matrix_size+1)/matrix_size
+
+    Algorithm:
+      A = diag(arange(1, matrix_size+1)/matrix_size)
+      while <number of stored entries of A is below density * matrix_size**2>:
+          <choose random indices i, j and skip the draw when i == j>
+          R = <rotation in the (i, j) plane by a random angle>
+          A = R^T A R
+
+    ``dtype`` (default ``torch.double``), ``device`` (default ``'cpu'``) and
+    an optional ``torch`` module override are read from ``kwargs``.
+    """
+    import math
+    # Allows the caller to substitute the module used to build the result.
+    torch = kwargs.get('torch', globals()['torch'])
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    # Sparse storage as {(row, col): value}; starts as the diagonal matrix.
+    data = {(i, i): float(i + 1) / matrix_size
+            for i in range(matrix_size)}
+
+
+    # Applies the rotation (cs, sn) in place to the pair of columns i, j
+    # (left=True, entries (k, i) and (k, j)) or to the pair of rows i, j
+    # (left=False, entries (i, k) and (j, k)) for every k in range(N).
+    def multiply(data, N, i, j, cs, sn, left=True):
+        for k in range(N):
+            if left:
+                ik, jk = (k, i), (k, j)
+            else:
+                ik, jk = (i, k), (j, k)
+            aik, ajk = data.get(ik, 0), data.get(jk, 0)
+            aik, ajk = cs * aik + sn * ajk, -sn * aik + cs * ajk
+            # Entries that come out exactly zero are dropped so that
+            # len(data) counts the stored nonzeros.
+            if aik:
+                data[ik] = aik
+            else:
+                data.pop(ik, None)
+            if ajk:
+                data[jk] = ajk
+            else:
+                data.pop(jk, None)
+
+    target_nnz = density * matrix_size * matrix_size
+    while len(data) < target_nnz:
+        i = random.randint(0, matrix_size - 1)
+        j = random.randint(0, matrix_size - 1)
+        if i != j:
+            theta = random.uniform(0, 2 * math.pi)
+            cs = math.cos(theta)
+            sn = math.sin(theta)
+            # Same rotation applied to columns and then to rows: A <- R^T A R.
+            multiply(data, matrix_size, i, j, cs, sn, left=True)
+            multiply(data, matrix_size, i, j, cs, sn, left=False)
+    # Emit the entries in sorted (row, col) order.
+    icoords, jcoords, values = [], [], []
+    for (i, j), v in sorted(data.items()):
+        icoords.append(i)
+        jcoords.append(j)
+        values.append(v)
+    indices_tensor = torch.tensor([icoords, jcoords])
+    return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device)
+
+# FIXME: remove this by updating test suites using it
+def do_test_dtypes(self, dtypes, layout, device):
+    for dtype in dtypes:
+        if dtype != torch.float16:
+            out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
+            self.assertIs(dtype, out.dtype)
+            self.assertIs(layout, out.layout)
+            self.assertEqual(device, out.device)
+
+# FIXME: remove this by updating test suites using it
+def do_test_empty_full(self, dtypes, layout, device):
+    shape = torch.Size([2, 3])
+
+    def check_value(tensor, dtype, layout, device, value, requires_grad):
+        self.assertEqual(shape, tensor.shape)
+        self.assertIs(dtype, tensor.dtype)
+        self.assertIs(layout, tensor.layout)
+        self.assertEqual(tensor.requires_grad, requires_grad)
+        if tensor.is_cuda and device is not None:
+            self.assertEqual(device, tensor.device)
+        if value is not None:
+            fill = tensor.new(shape).fill_(value)
+            self.assertEqual(tensor, fill)
+
+    def get_int64_dtype(dtype):
+        module = '.'.join(str(dtype).split('.')[1:-1])
+        if not module:
+            return torch.int64
+        return operator.attrgetter(module)(torch).int64
+
+    default_dtype = torch.get_default_dtype()
+    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
+    check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False)
+    for dtype in dtypes:
+        for rg in {dtype.is_floating_point, False}:
+            int64_dtype = get_int64_dtype(dtype)
+            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
+            check_value(v, dtype, layout, device, None, rg)
+            out = v.new()
+            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
+                        dtype, layout, device, None, rg)
+            check_value(v.new_empty(shape), dtype, layout, device, None, False)
+            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
+            check_value(torch.empty_like(v), dtype, layout, device, None, False)
+            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
+
+            if dtype is not torch.float16 and layout != torch.sparse_coo:
+                fv = 3
+                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
+                check_value(v, dtype, layout, device, fv, rg)
+                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
+                out = v.new()
+                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
+                            dtype, layout, device, fv + 2, rg)
+                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
+                            int64_dtype, layout, device, fv + 3, False)
+                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
+                check_value(torch.full_like(v, fv + 5,
+                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
+                            int64_dtype, layout, device, fv + 5, False)
+
+# FIXME: improve load_tests() documentation here
+running_script_path = None  # type: ignore[var-annotated]
+def set_running_script_path():
+    global running_script_path
+    try:
+        running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
+        if running_file.endswith('.py'):  # skip if the running file is not a script
+            running_script_path = running_file
+    except Exception:
+        pass
+
+def check_test_defined_in_running_script(test_case):
+    if running_script_path is None:
+        return
+    test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
+    assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \
+        f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \
+        "accidentally import a unittest.TestCase from another file?"
+
+def load_tests(loader, tests, pattern):
+    set_running_script_path()
+    test_suite = unittest.TestSuite()
+    for test_group in tests:
+        if not DISABLE_RUNNING_SCRIPT_CHK:
+            for test in test_group:
+                check_test_defined_in_running_script(test)
+        if test_group._tests:
+            test_suite.addTest(test_group)
+    return test_suite
+
+# FIXME: document this and move it to test_serialization
+class BytesIOContext(io.BytesIO):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+# Tentative value for nondet_tol for gradcheck when backward implementation
+# relies on nondeterministic operations, i.e., those listed here:
+# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+#
+# For more information see https://github.com/pytorch/pytorch/issues/56202
+GRADCHECK_NONDET_TOL = 1e-12
+
+# Flag registered with TestEnvironment under the PYTORCH_TEST_WITH_SLOW_GRADCHECK
+# environment variable; when set, the gradcheck/gradgradcheck wrappers in this
+# file default to fast_mode=False.
+TEST_WITH_SLOW_GRADCHECK: bool = TestEnvironment.def_flag(
+    "TEST_WITH_SLOW_GRADCHECK",
+    env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK",
+)
+
+# Decorator that skips a test whenever TEST_WITH_SLOW_GRADCHECK is set.
+skipIfSlowGradcheckEnv = unittest.skipIf(
+    TEST_WITH_SLOW_GRADCHECK,
+    "Tests that don't use gradcheck don't need to run on slow_gradcheck CI",
+)
+
+
+def gradcheck(fn, inputs, **kwargs):
+    # Wrapper around gradcheck that enables certain keys by default.
+    # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
+    # forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks
+    # to be disabled to default for the public-facing api to avoid breaking user code.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
+    default_values = {
+        "check_batched_grad": True,
+        "fast_mode": True,
+    }
+
+    if TEST_WITH_SLOW_GRADCHECK:
+        default_values["fast_mode"] = False
+
+    for key, value in default_values.items():
+        # default value override values explicitly set to None
+        k = kwargs.get(key)
+        kwargs[key] = k if k is not None else value
+
+    return torch.autograd.gradcheck(fn, inputs, **kwargs)
+
+def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
+    # Wrapper around gradgradcheck that enables certain keys by default
+    # See gradcheck above for an explanation of why we need something like this.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck
+    default_values = {
+        "check_batched_grad": True,
+        "fast_mode": True,
+    }
+
+    if TEST_WITH_SLOW_GRADCHECK:
+        default_values["fast_mode"] = False
+
+    for key, value in default_values.items():
+        # default value override values explicitly set to None
+        k = kwargs.get(key)
+        kwargs[key] = k if k is not None else value
+
+    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
+
+
+def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
+    # call assert function rather than returning a bool since it's nicer
+    # if we get whether this failed on the gradcheck or the gradgradcheck.
+    test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
+    test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))
+
+
+@contextmanager
+def set_cwd(path: str) -> Iterator[None]:
+    old_cwd = os.getcwd()
+    try:
+        os.chdir(path)
+        yield
+    finally:
+        os.chdir(old_cwd)
+
+
+# FIXME: delete this
+# Using @toleranceOverride specific to your test is the recommended way
+# of doing this. These are just some values that worked for test_nn.
+# Maps a floating-point dtype to a precision value (presumably used as a
+# comparison tolerance; verify against callers).
+dtype2prec_DONTUSE = {torch.float: 1e-5,
+                      torch.double: 1e-5,
+                      torch.half: 1e-2,
+                      torch.bfloat16: 1e-1}
+
+# FIXME: move to test_sparse or sparse utils
+# This is a wrapper that wraps a test to run this test twice, one with
+# coalesced=True, another with coalesced=False for coalesced/uncoalesced sparse tensors.
+def coalescedonoff(f):
+    @wraps(f)
+    def wrapped(self, *args, **kwargs):
+        f(self, *args, **kwargs, coalesced=True)
+        f(self, *args, **kwargs, coalesced=False)
+    return wrapped
+
+
+def is_coalesced_indices(s):
+    indices = s._indices()
+    hash_coeffs = (1,) + s.shape[s.sparse_dim() - 1:0:-1]
+    hash_indices = torch.tensor(hash_coeffs, device=s.device).cumprod(-1).flip(-1)
+    if s.sparse_dim() > 1:
+        hash_indices.unsqueeze_(-1)
+        hash_indices = (indices * hash_indices).sum(0)
+    else:
+        hash_indices = indices * hash_indices
+
+    # check if indices are sorted
+    res = torch.allclose(hash_indices, hash_indices.sort()[0])
+
+    # check if there are no repeated indices
+    res = res and torch.allclose(hash_indices, hash_indices.unique())
+
+    return res
+
+
+@contextlib.contextmanager
+def disable_gc():
+    if gc.isenabled():
+        try:
+            gc.disable()
+            yield
+        finally:
+            gc.enable()
+    else:
+        yield
+
+
+def find_library_location(lib_name: str) -> Path:
+    # return the shared library file in the installed folder if exist,
+    # else the file in the build folder
+    torch_root = Path(torch.__file__).resolve().parent
+    path = torch_root / 'lib' / lib_name
+    if os.path.exists(path):
+        return path
+    torch_root = Path(__file__).resolve().parents[2]
+    return torch_root / 'build' / 'lib' / lib_name
+
+def skip_but_pass_in_sandcastle(reason):
+    """
+    Similar to unittest.skip, however in the sandcastle environment it just
+    "passes" the test instead to avoid creating tasks complaining about tests
+    skipping continuously.
+    """
+    def decorator(func):
+        if not IS_SANDCASTLE:
+            func.__unittest_skip__ = True
+            func.__unittest_skip_why__ = reason
+            return func
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
+            return
+        return wrapper
+
+    return decorator
+
+def mock_wrapper(method):
+    """
+    Returns a function that calls the real implementation of a method
+    in addition to passing args to a mock object.
+    """
+    mock = MagicMock()
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        mock(*args, **kwargs)
+        return method(self, *args, **kwargs)
+    wrapper.mock = mock  # type: ignore[attr-defined]
+    return wrapper
+
+def get_tensors_from(args, kwargs):
+    """ Returns a set of all Tensor objects in the given args and kwargs. """
+    return set([arg for arg in args if isinstance(arg, Tensor)] +
+               [v for v in kwargs.values() if isinstance(v, Tensor)])
+
+
+# Returns scalar tensor representation of a list of integer byte values
+def bytes_to_scalar(byte_list: list[int], dtype: torch.dtype, device: torch.device):
+    dtype_to_ctype: dict[torch.dtype, Any] = {
+        torch.int8: ctypes.c_int8,
+        torch.uint8: ctypes.c_uint8,
+        torch.uint16: ctypes.c_uint16,
+        torch.uint32: ctypes.c_uint32,
+        torch.uint64: ctypes.c_uint64,
+        torch.int16: ctypes.c_int16,
+        torch.int32: ctypes.c_int32,
+        torch.int64: ctypes.c_int64,
+        torch.bool: ctypes.c_bool,
+        torch.float32: ctypes.c_float,
+        torch.complex64: ctypes.c_float,
+        torch.float64: ctypes.c_double,
+        torch.complex128: ctypes.c_double,
+    }
+    ctype = dtype_to_ctype[dtype]
+    num_bytes = ctypes.sizeof(ctype)
+
+    def check_bytes(byte_list):
+        for byte in byte_list:
+            assert 0 <= byte <= 255
+
+    if dtype.is_complex:
+        assert len(byte_list) == (num_bytes * 2)
+        check_bytes(byte_list)
+        real = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list[:num_bytes])).value
+        imag = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list[num_bytes:])).value
+        res = real + 1j * imag
+    else:
+        assert len(byte_list) == num_bytes
+        check_bytes(byte_list)
+        res = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list)).value
+
+    return torch.tensor(res, device=device, dtype=dtype)
+
+
+def copy_func(f):
+    """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
+                           argdefs=f.__defaults__,
+                           closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__  # type: ignore[attr-defined]
+    return g
+
+
+def xfail_inherited_tests(tests):
+    """
+    Given a list of test names which are defined by a superclass of the
+    class this decorates, mark them as expected failure.  This is useful
+    if you are doing poor man's parameterized tests by subclassing a generic
+    test class.
+    """
+    def deco(cls):
+        for t in tests:
+            # NB: expectedFailure operates by mutating the method in question,
+            # which is why you have to copy the function first
+            setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t))))
+        return cls
+    return deco
+
+
+def skip_but_pass_in_sandcastle_if(condition, reason):
+    """
+    Similar to unittest.skipIf, however in the sandcastle environment it just
+    "passes" the test instead to avoid creating tasks complaining about tests
+    skipping continuously.
+    """
+    def decorator(func):
+        if condition:
+            if IS_SANDCASTLE:
+                @wraps(func)
+                def wrapper(*args, **kwargs):
+                    print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
+                return wrapper
+            else:
+                func.__unittest_skip__ = True
+                func.__unittest_skip_why__ = reason
+
+        return func
+
+    return decorator
+
+def dtype_name(dtype):
+    """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
+    return str(dtype).split('.')[1]
+
+
+@functools.lru_cache
+def get_cycles_per_ms() -> float:
+    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
+    """
+
+    def measure() -> float:
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        torch.cuda._sleep(1000000)
+        end.record()
+        end.synchronize()
+        cycles_per_ms = 1000000 / start.elapsed_time(end)
+        return cycles_per_ms
+
+    # Get 10 values and remove the 2 max and 2 min and return the avg.
+    # This is to avoid system disturbance that skew the results, e.g.
+    # the very first cuda call likely does a bunch of init, which takes
+    # much longer than subsequent calls.
+    #
+    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
+    # and seems to return stable values. Therefore, we enable caching
+    # using lru_cache decorator above.
+    num = 10
+    vals = [measure() for _ in range(num)]
+    vals = sorted(vals)
+    return mean(vals[2 : num - 2])
+
+
+# OpInfo utils
+
+T = TypeVar('T')
+def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
+    """
+    Returns the first sample from an iterable of samples, like those returned by OpInfo.
+    The test will be skipped if no samples are available.
+    """
+    try:
+        return next(iter(samples))
+    except StopIteration as e:
+        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e
+
+# this helper method is to recursively
+# clone the tensor-type input of operators tested by OpInfo
+def clone_input_helper(input):
+    if isinstance(input, torch.Tensor):
+        return torch.clone(input)
+
+    if isinstance(input, Sequence):
+        return tuple(map(clone_input_helper, input))
+
+    return input
+
+@contextmanager
+def custom_op(opname, symbolic_fn, opset_version):
+    """Context manager/decorator to test ONNX export with custom operator"""
+    try:
+        register_custom_op_symbolic(opname, symbolic_fn, opset_version)
+        yield
+    finally:
+        unregister_custom_op_symbolic(opname, opset_version)
+
+
+def outs_and_grads(fn, graph_inps, inps):
+    outs = fn(*graph_inps)
+    for out in pytree.tree_leaves(outs):
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            out.sum().backward(retain_graph=True)
+    grads = [inp.grad for inp in pytree.tree_leaves(inps) if isinstance(inp, torch.Tensor)]
+    for inp in pytree.tree_leaves(inps):
+        if isinstance(inp, torch.Tensor):
+            inp.grad = None
+    return outs, grads
+
+def compare_equal_outs_and_grads(test, m1, m2, inps):
+    r1, g1 = outs_and_grads(m1, inps, inps)
+    r2, g2 = outs_and_grads(m2, inps, inps)
+    test.assertEqual(r1, r2)
+    test.assertEqual(g1, g2)
+
+class TestGradients(TestCase):
+    """Shared helpers for tests that run gradcheck / gradgradcheck over an op's sample inputs."""
+
+    exact_dtype = True
+
+    # Copies inputs to inplace operations to avoid inplace modifications
+    #   to leaves requiring gradient
+    def _get_safe_inplace(self, inplace_variant):
+        """Return a wrapper that calls ``inplace_variant`` on a clone of its first argument."""
+        @wraps(inplace_variant)
+        def _fn(t, *args, **kwargs):
+            return inplace_variant(t.clone(), *args, **kwargs)
+
+        return _fn
+
+    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
+                      check_batched_grad=None, check_batched_forward_grad=False):
+        """Run the requested gradient check on ``variant`` for every sample input of ``op``.
+
+        ``check`` selects the check: 'gradcheck' runs ``gradcheck``, while
+        'bwgrad_bwgrad' and 'fwgrad_bwgrad' run ``gradgradcheck``.  The test is
+        skipped when ``variant`` is None or ``op`` does not support ``dtype``
+        on the device type.  ``check_batched_grad=None`` means "use
+        ``op.check_batched_grad``" (gradcheck mode only).
+        """
+        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
+        # NB: check_backward_ad does not affect gradgradcheck (always True)
+        if variant is None:
+            self.skipTest("Skipped! Variant not implemented.")
+        if not op.supports_dtype(dtype, torch.device(device).type):
+            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
+
+        def is_inplace(variant):
+            # Look through a functools.wraps-style wrapper (e.g. the one made
+            # by _get_safe_inplace) to compare against the op's inplace variant.
+            if hasattr(variant, "__wrapped__"):
+                return variant.__wrapped__ is op.get_inplace()
+            return variant is op.get_inplace()
+
+        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
+
+        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
+                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)
+
+        for sample in samples:
+            # Samples that broadcast the input are not run through inplace variants.
+            if sample.broadcasts_input and is_inplace(variant):
+                continue
+
+            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
+            #   and tensors passed as kwargs. The following creates a function that accepts just
+            #   the tensors that require grad as varargs, and then recomposes them back into the
+            #   original input.
+
+            # Creates gradcheck inputs by identifying tensors requiring grad
+            all_args = None
+            if is_iterable_of_tensors(sample.input):
+                all_args = chain(sample.input, sample.args, sample.kwargs.values())
+            else:
+                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))  # type: ignore[assignment]
+            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))  # type: ignore[union-attr]
+
+            # Verifies sample input tensors should have no grad
+            # This may happen if the same tensor is used in two different SampleInputs
+            for t in gradcheck_args:
+                self.assertIsNone(t.grad,
+                                  "A sampled input has a gradient before running autograd. "
+                                  "This usually means that (at least) one input tensor is reused "
+                                  "across different SampleInputs. "
+                                  "Please create a new tensor for each SampleInput.")
+
+            def _input_recomposition_helper(inputs, inp, input_idx):
+                # Substitute the next entries of ``inputs`` for the grad-requiring
+                # tensors in ``inp``; returns the rebuilt value and the advanced index.
+                if is_iterable_of_tensors(inp):
+                    tensor_list = []
+                    for x in inp:
+                        if isinstance(x, torch.Tensor) and x.requires_grad:
+                            tensor_list.append(inputs[input_idx])
+                            input_idx = input_idx + 1
+                        else:
+                            tensor_list.append(x)
+                    return tensor_list, input_idx
+                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
+                    return inputs[input_idx], input_idx + 1
+                else:
+                    return inp, input_idx
+
+            def fn(*inputs):
+                # Puts inputs back into sample properly
+                positional_args = []
+                input_idx = 0
+                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
+                positional_args.append(inp)
+
+                for x in sample.args:
+                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
+                    positional_args.append(inp)
+
+                # Recreates kwargs
+                kwargs = {}
+                for k, v in sample.kwargs.items():
+                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
+                    kwargs[k] = inp
+
+                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
+                if sample.output_process_fn_grad is not None:
+                    return sample.output_process_fn_grad(output)
+                return output
+
+            if check == 'gradcheck':
+                # Resolved on the first sample and then reused for the remaining samples.
+                if check_batched_grad is None:
+                    check_batched_grad = op.check_batched_grad
+                self.assertTrue(gradcheck(fn, gradcheck_args,
+                                          check_batched_grad=check_batched_grad,
+                                          check_grad_dtypes=True,
+                                          nondet_tol=op.gradcheck_nondet_tol,
+                                          fast_mode=op.gradcheck_fast_mode,
+                                          check_forward_ad=check_forward_ad,
+                                          check_backward_ad=check_backward_ad,
+                                          check_undefined_grad=True,
+                                          check_batched_forward_grad=check_batched_forward_grad))
+            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
+                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
+                # Run once with the default grad outputs and once with non-contiguous ones.
+                for gen_non_contig_grad_outputs in (False, True):
+                    kwargs = {
+                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
+                        "check_batched_grad": op.check_batched_gradgrad,
+                        "check_grad_dtypes": True,
+                        "nondet_tol": op.gradcheck_nondet_tol,
+                        "fast_mode": op.gradcheck_fast_mode
+                    }
+                    if check == "fwgrad_bwgrad":
+                        kwargs["check_fwd_over_rev"] = True
+                        kwargs["check_rev_over_rev"] = False
+                        kwargs["check_batched_grad"] = False
+                        kwargs["check_undefined_grad"] = False
+
+                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
+            else:
+                # Not reachable while the assert at the top of this method is active.
+                self.assertTrue(False, msg="Unknown check requested!")
+
+    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
+                          check_batched_grad=None, check_batched_forward_grad=False):
+        """Run ``_check_helper`` in 'gradcheck' mode, forwarding all options unchanged."""
+        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
+                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
+                                  check_batched_forward_grad=check_batched_forward_grad)
+
+    def _skip_helper(self, op, device, dtype):
+        """Skip the test if ``op`` has no backward support for ``dtype`` on the device type, or supports neither autograd nor forward AD."""
+        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
+            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
+        if not op.supports_autograd and not op.supports_forward_ad:
+            self.skipTest("Skipped! autograd not supported.")
+
+
+
+
+# Base TestCase for NT tests; used to define common helpers, etc.
+class NestedTensorTestCase(TestCase):
+    def assertEqualIgnoringNestedInts(self, a, b):
+        # unbinding NJTs allows us to compare them as essentially equal without
+        # caring about exact nested int comparison
+        def _unbind_njts(x):
+            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
+                return x.unbind()
+            else:
+                return x
+
+        self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
+
+    def assertEqualNoncontigAware(self, a, b):
+        # assertEqual() doesn't take into account lengths, so hack around this
+        # by comparing unbound components and shapes
+        self.assertEqualIgnoringNestedInts(a, b)
+
+        def _get_njt_shapes(x):
+            return (
+                x.shape
+                if isinstance(x, torch.Tensor) and x.is_nested
+                else None
+            )
+
+        a_shapes = pytree.tree_map(_get_njt_shapes, a)
+        b_shapes = pytree.tree_map(_get_njt_shapes, b)
+        self.assertEqual(a_shapes, b_shapes)
+
+    @contextlib.contextmanager
+    def branch_nested_state(self):
+        """Context manager to branch and restore the nested tensor state."""
+        nested_tensor_module = torch.nested._internal.nested_tensor
+        original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
+        original_tensor_id_counter = nested_tensor_module._tensor_id_counter
+        try:
+            yield
+        finally:
+            nested_tensor_module._tensor_id_counter = original_tensor_id_counter
+            nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
+
+
+def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
+    from torch._dynamo.trace_rules import _as_posix_path
+
+    if file is None:
+        file = inspect.stack()[1 + skip].filename  # skip one frame
+
+    file = _as_posix_path(file)
+    s = _as_posix_path(str(e))
+
+    # Remove everything that looks like stack frames in NOT this file
+    def repl_frame(m):
+        if m.group(1) != file:
+            return ""
+        # Don't accept top-level, even for this script, these will wobble
+        # depending on how the testing script was invoked
+        if m.group(2) == "":
+            return ""
+
+        return m.group(0)
+
+    s = re.sub(r'  File "([^"]+)", line \d+, in (.+)\n(    .+\n( +[~^]+ *\n)?)+', repl_frame, s)
+    s = re.sub(r"line \d+", "line N", s)
+    s = re.sub(r".py:\d+", ".py:N", s)
+    s = re.sub(r'https:/([a-zA-Z0-9_.-]+)', r'https://\1', s)
+    s = re.sub(file, _as_posix_path(os.path.basename(file)), s)
+    s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s)
+    # 3.10 CALL_FUNCTION bytecode compatibility for dynamo graph break messages
+    s = re.sub(
+        r"attempting to trace CALL_FUNCTION:.*$",
+        "attempting to trace CALL: a function call, e.g. f(x, y):",
+        s,
+        flags=re.MULTILINE,
+    )
+    if suppress_suffix:
+        s = re.sub(r"\n*Set TORCH_LOGS.+", "", s, flags=re.DOTALL)
+        s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
+        s = re.sub(r"\n*Set TORCHDYNAMO_VERBOSE=1.+", "", s, flags=re.DOTALL)
+    if suppress_prefix:
+        s = re.sub(r"Cannot export model.+\n\n", "", s)
+    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
+    return s
+
+
+@contextmanager
+def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
+    """Wrap around operations you want to ensure are not leaking tensor memory.
+
+    This code intentionally ignores other reference cycles, which can be benign and which we have plenty
+    of in pytorch code.  It focuses on any reference cycles that directly or indirectly result holding a Tensor alive,
+    since this is likely a more serious leak than typical python refcycles.
+
+    limit specifies how many tensors to dump debug graphs for (default=1)
+    """
+    def match_obj(obj):
+        return isinstance(obj, matched_type)
+
+    try:
+        gc.collect()
+        gc.set_debug(gc.DEBUG_SAVEALL)
+        garbage_objs = []  # type: ignore[var-annotated]
+
+        # run the user code, after cleaning any existing refcycles, and then check for new ones
+        # also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr
+        yield garbage_objs
+
+        gc.collect()
+        garbage_objs.extend(filter(match_obj, gc.garbage))
+        num_garbage_objs = len(garbage_objs)
+        if num_garbage_objs > 0:
+            warnings.warn(
+                f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?", stacklevel=2
+            )
+            try:
+                import objgraph  # type: ignore[import-not-found,import-untyped]
+                warnings.warn(
+                    f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png", stacklevel=2
+                )
+                for g in garbage_objs[:limit]:
+                    objgraph.show_backrefs([g], max_depth=10)
+            except ImportError:
+                warnings.warn("`pip install objgraph` to enable memory leak debugging", stacklevel=2)
+
+    finally:
+        gc.set_debug(0)
+
+
+def remove_cpp_extensions_build_root():
+    """
+    Removes the default root folder under which extensions are built.
+    """
+    default_build_root = cpp_extension.get_default_build_root()
+    if os.path.exists(default_build_root):
+        if IS_WINDOWS:
+            # rmtree returns permission error: [WinError 5] Access is denied
+            # on Windows, this is a workaround
+            subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
+        else:
+            shutil.rmtree(default_build_root, ignore_errors=True)
+
+
+def install_cpp_extension(extension_root):
+    # Wipe the build / install dirs if they exist
+    build_dir = os.path.join(extension_root, "build")
+    install_dir = os.path.join(extension_root, "install")
+    for d in (build_dir, install_dir):
+        if os.path.exists(d):
+            shutil.rmtree(d)
+
+    # Build the extension
+    cmd = [sys.executable, "-m", "pip", "install", extension_root, "-v", "--no-build-isolation", "--root", install_dir]
+    return_code = shell(cmd, cwd=extension_root, env=os.environ)
+    if return_code != 0:
+        raise RuntimeError(f"build failed for cpp extension at {extension_root}")
+
+    mod_install_dir = None
+    # install directory is the one that is named site-packages
+    for root, directories, _ in os.walk(install_dir):
+        for directory in directories:
+            if "-packages" in directory:
+                mod_install_dir = os.path.join(root, directory)
+
+    if mod_install_dir is None:
+        raise RuntimeError(f"installation failed for cpp extension at {extension_root}")
+
+    if mod_install_dir not in sys.path:
+        sys.path.insert(0, mod_install_dir)
+
+
+# Decorator to provide a helper to load inline extensions to a temp directory
+def scoped_load_inline(func):
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        def load_inline(*args, **kwargs):
+            if IS_WINDOWS:
+                # TODO(xmfan): even using TemporaryDirectoryName will result in permission error
+                return cpp_extension.load_inline(*args, **kwargs)
+
+            assert "build_directory" not in kwargs
+            with TemporaryDirectoryName() as temp_dir_name:
+                if kwargs.get("verbose", False):
+                    print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr)
+                kwargs["build_directory"] = temp_dir_name
+                return cpp_extension.load_inline(*args, **kwargs)
+
+        return func(*args, load_inline=load_inline, **kwargs)
+    return wrapper
+
+def recover_orig_fp32_precision(fn):
+    @contextlib.contextmanager
+    def recover():
+        old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision  # type: ignore[attr-defined]
+        old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision  # type: ignore[attr-defined]
+        old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision  # type: ignore[attr-defined]
+        old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision  # type: ignore[attr-defined]
+        old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision  # type: ignore[attr-defined]
+        old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
+        try:
+            yield
+        finally:
+            torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p  # type: ignore[attr-defined]
+            torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p  # type: ignore[attr-defined]
+            torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p  # type: ignore[attr-defined]
+            torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p  # type: ignore[attr-defined]
+            torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p  # type: ignore[attr-defined]
+            torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
+
+    return recover()(fn)
+
+def skipIfPythonVersionMismatch(predicate):
+    vi = sys.version_info
+
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if predicate(vi.major, vi.minor, vi.micro):
+                return fn(self, *args, **kwargs)
+            else:
+                raise unittest.SkipTest("Python version mismatch")
+        return wrap_fn
+    return dec_fn
+
+# Decorator to patch multiple test class members for the duration of the subtest
+def patch_test_members(updates: dict[str, Any]):
+    def decorator(test_func):
+        @wraps(test_func)
+        def wrapper(self, *args, **kwargs):
+            # Store the original values of the specified members
+            original_values = {member: getattr(self, member) for member in updates}
+
+            # Update the members before running the subtest
+            for member, value in updates.items():
+                setattr(self, member, value)
+
+            # Run the test function, allowing subtests to run
+            try:
+                return test_func(self, *args, **kwargs)
+            finally:
+                # Restore the original values of the specified members after the subtest finishes
+                for member, original_value in original_values.items():
+                    setattr(self, member, original_value)
+
+        return wrapper
+    return decorator
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..32982d0a3e2a358a2530abd234b37a24c6efe77d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py
@@ -0,0 +1,585 @@
+# mypy: allow-untyped-defs
+import torch
+import functools
+from torch.testing import make_tensor
+from torch.testing._internal.opinfo.core import (
+    OpInfo,
+    SampleInput,
+)
+from torch.testing._internal.common_dtype import all_types_and
+import numpy as np
+from torch.testing._internal.autograd_function_db import (
+    sample_inputs_numpy_cube,
+    sample_inputs_numpy_mul,
+    sample_inputs_numpy_mul_scalar,
+    sample_inputs_numpy_sort,
+    sample_inputs_numpy_take,
+)
+from torch import Tensor
+from torch.types import Number
+from typing import *  # noqa: F403
+
+# Note: [custom op db]
+#
+# This is a collection of custom operator test cases written as OpInfos
+# so they can easily be consumed by OpInfo-based tests to check if subsystems
+# support them correctly.
+
+def to_numpy(tensor):
+    return tensor.cpu().numpy()
+
+@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=())
+def numpy_cube(x: Tensor) -> tuple[Tensor, Tensor]:
+    x_np = to_numpy(x)
+    dx = torch.tensor(3 * x_np ** 2, device=x.device)
+    return torch.tensor(x_np ** 3, device=x.device), dx
+
+@numpy_cube.register_fake
+def _(x):
+    return x.clone(), x.clone()
+
+def numpy_cube_setup_context(ctx, inputs, output):
+    x, = inputs
+    _cube, dx = output
+    ctx.save_for_backward(x, dx)
+
+def numpy_cube_backward(ctx, grad_out, grad_dx):
+    x, dx = ctx.saved_tensors
+    grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
+    return grad_x
+
+numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
+
+def numpy_cube_vmap(info, in_dims, x):
+    result = numpy_cube(x)
+    return result, (in_dims[0], in_dims[0])
+
+numpy_cube.register_vmap(numpy_cube_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
+def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
+    return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
+
+@numpy_mul.register_fake
+def _(x, y):
+    assert x.device == y.device
+    return (x * y).contiguous()
+
+def numpy_mul_setup_context(ctx, inputs, output):
+    ctx.save_for_backward(*inputs)
+
+def numpy_mul_backward(ctx, grad_out):
+    x, y = ctx.saved_tensors
+    grad_x = grad_out * y if ctx.needs_input_grad[0] else None
+    grad_y = grad_out * x if ctx.needs_input_grad[1] else None
+    return grad_x, grad_y
+
+numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
+
+def numpy_mul_vmap(info, in_dims, x, y):
+    x_bdim, y_bdim = in_dims
+    x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
+    y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
+    result = x * y
+    result = result.movedim(-1, 0)
+    return result, 0
+
+numpy_mul.register_vmap(numpy_mul_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
+def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
+    return torch.tensor(to_numpy(x) * scalar, device=x.device)
+
+@numpy_mul_scalar.register_fake
+def _(x, *, scalar):
+    return (x * scalar).contiguous()
+
+def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output):
+    ctx.scalar = keyword_only_inputs["scalar"]
+
+def numpy_mul_scalar_backward(ctx, grad_out):
+    grad_x = grad_out * ctx.scalar
+    return grad_x
+
+numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
+
+def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
+    x_bdim, = in_dims
+    x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
+    result = x * scalar
+    result = result.movedim(-1, 0)
+    return result, 0
+
+numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
+def numpy_sort(x: Tensor, dim: int) -> tuple[Tensor, Tensor, Tensor]:
+    device = x.device
+    x = to_numpy(x)
+    ind = np.argsort(x, axis=dim)
+    ind_inv = np.argsort(ind, axis=dim)
+    result = np.take_along_axis(x, ind, axis=dim)
+    return (
+        torch.tensor(result, device=device),
+        torch.tensor(ind, device=device),
+        torch.tensor(ind_inv, device=device),
+    )
+
+@numpy_sort.register_fake
+def _(x, dim):
+    return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
+
+def numpy_sort_setup_context(ctx, inputs, output):
+    _out, ind, ind_inv = output
+    ctx.dim = inputs[1]
+    ctx.save_for_backward(ind, ind_inv)
+    ctx.mark_non_differentiable(ind, ind_inv)
+
+def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
+    ind, ind_inv = ctx.saved_tensors
+    return numpy_take(grad_out, ind_inv, ind, ctx.dim), None
+
+numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
+
+def numpy_sort_vmap(info, in_dims, x, dim):
+    x_bdim, _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    dim = dim if dim >= 0 else dim + x.dim() - 1
+    result = numpy_sort(x, dim + 1)
+    return result, (0, 0, 0)
+
+numpy_sort.register_vmap(numpy_sort_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
+def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
+    device = x.device
+    x = to_numpy(x)
+    ind = to_numpy(ind)
+    return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
+
+@numpy_take.register_fake
+def _(x, ind, ind_inv, dim):
+    assert x.device == ind.device
+    assert x.device == ind_inv.device
+    assert ind.dtype == torch.long
+    assert ind_inv.dtype == torch.long
+    return torch.empty_like(x)
+
+def numpy_take_setup_context(ctx, inputs, output):
+    _x, ind, ind_inv, dim = inputs
+    ctx.dim = dim
+    ctx.save_for_backward(ind, ind_inv)
+
+def numpy_take_backward(ctx, grad_out):
+    ind, ind_inv = ctx.saved_tensors
+    grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
+    return grad_x, None, None, None
+
+numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
+
+def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
+    x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
+
+    # wrap dim
+    logical_dim = x.dim() if x_bdim is None else x_bdim - 1
+    dim = dim if dim >= 0 else dim + logical_dim
+
+    def expand_bdim(x, x_bdim):
+        if x_bdim is None:
+            return x.expand(info.batch_size, *x.shape)
+        return x.movedim(x_bdim, 0)
+
+    x = expand_bdim(x, x_bdim)
+    ind = expand_bdim(ind, ind_bdim)
+    ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
+
+    return numpy_take(x, ind, ind_inv, dim + 1), 0
+
+numpy_take.register_vmap(numpy_take_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
+def numpy_nonzero(x: Tensor) -> Tensor:
+    x_np = to_numpy(x)
+    res = np.stack(np.nonzero(x_np), axis=1)
+    if res.shape[0] <= 1:
+        raise RuntimeError("not supported")
+    return torch.tensor(res, device=x.device)
+
+@numpy_nonzero.register_fake
+def _(x):
+    ctx = torch._custom_op.impl.get_ctx()
+    i0 = ctx.create_unbacked_symint()
+    shape = [i0, x.dim()]
+    result = x.new_empty(shape, dtype=torch.long)
+    return result
+
+def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    shape = 10
+    result = make_arg(shape, low=0.9, high=2)
+    mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long)
+    with torch.no_grad():
+        result *= mask
+
+    yield SampleInput(result, args=())
+
+def numpy_nonzero_vmap(info, in_dims, x):
+    raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
+
+numpy_nonzero.register_vmap(numpy_nonzero_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
+def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
+    return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)
+
+@numpy_view_copy.register_fake
+def _(x, shape) -> Tensor:
+    return x.clone().view(shape).clone()
+
+def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
+    ctx.x_shape = inputs[0].shape
+
+def numpy_view_copy_backward(ctx, grad_out):
+    return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None
+
+numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
+
+def numpy_view_copy_vmap(info, in_dims, x, shape):
+    x_bdim, _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    x_shape = x.shape[0]
+    batch_shape = (x_shape, *shape)
+    result = numpy_view_copy(x, batch_shape)
+    return result, 0
+
+numpy_view_copy.register_vmap(numpy_view_copy_vmap)
+
+def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    result = make_arg(2, 3, 4, low=0.9, high=2)
+    yield SampleInput(result, args=([2, 12],))
+
+@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=())
+def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
+    assert len(xs) > 0
+    assert all(x.device == xs[0].device for x in xs)
+    assert all(x.dtype == xs[0].dtype for x in xs)
+    np_xs = [to_numpy(x) for x in xs]
+    np_out = np.concatenate(np_xs, axis=dim)
+    return torch.tensor(np_out, device=xs[0].device)
+
+@numpy_cat.register_fake
+def _(xs, dim):
+    assert len(xs) > 0
+    assert all(x.device == xs[0].device for x in xs)
+    assert all(x.dtype == xs[0].dtype for x in xs)
+    return torch.cat(xs, dim=dim)
+
+def numpy_cat_setup_context(ctx, inputs, output):
+    xs, dim = inputs
+    ctx.dim_sizes = [x.shape[dim] for x in xs]
+    ctx.dim = dim
+
+def numpy_cat_backward(ctx, grad_out):
+    dim_sizes = ctx.dim_sizes
+    dim = ctx.dim
+
+    splits = list(np.cumsum(dim_sizes)[:-1])
+    grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
+    return grad_xs, None
+
+numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
+
+def numpy_cat_vmap(info, in_dims, x, dim):
+    x_bdim, = in_dims
+    result = numpy_cat(x, dim)
+    return result, x_bdim
+
+numpy_cat.register_vmap(numpy_cat_vmap)
+
+def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    r0 = make_arg(2, 3, 4, low=0.9, high=2)
+    r1 = make_arg(4, 3, 4, low=0.9, high=2)
+    r2 = make_arg(5, 3, 4, low=0.9, high=2)
+    yield SampleInput([r0, r1, r2], args=(0,))
+
+@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=())
+def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]:
+    x_np = to_numpy(x)
+    arrs = np.split(x_np, splits, axis=dim)
+    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]
+
+@numpy_split_copy.register_fake
+def _(x, splits, dim):
+    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]
+
+def numpy_split_copy_setup_context(ctx, inputs, output):
+    _, _, dim = inputs
+    ctx.dim = dim
+
+def numpy_split_copy_backward(ctx, grad_out):
+    result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
+    return result, None, None
+
+numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
+
+def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
+    x_bdim, _ , _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    result = numpy_split_copy(x, splits, dim + 1)
+    return result, 0
+
+numpy_split_copy.register_vmap(numpy_split_copy_vmap)
+
+def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    x = make_arg(2, 9, low=0.9, high=2)
+    yield SampleInput(x, args=([1, 3, 6], 1))
+
+@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=())
+def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> tuple[List[Tensor], int]:
+    x_np = to_numpy(x)
+    arrs = np.split(x_np, splits, axis=dim)
+    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)
+
+@numpy_split_copy_with_int.register_fake
+def _(x, splits, dim):
+    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)
+
+def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
+    _, _, dim = inputs
+    ctx.dim = dim
+
+def numpy_split_copy_with_int_backward(ctx, grad_out, _):
+    return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None
+
+numpy_split_copy_with_int.register_autograd(
+    numpy_split_copy_with_int_backward,
+    setup_context=numpy_split_copy_with_int_setup_context)
+
+def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
+    x_bdim, _ , _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
+    return (result, len_split), ([0 for _ in range(len(result))], None)
+
+numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
+def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
+    # Adapted from Ross Girshick's fast-rcnn implementation at
+    # https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
+    assert boxes.device == scores.device
+    device = boxes.device
+
+    boxes = to_numpy(boxes)
+    scores = to_numpy(scores)
+
+    N = boxes.shape[0]
+    assert boxes.shape == (N, 4)
+    assert scores.shape == (N,)
+
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= iou_threshold)[0]
+        order = order[inds + 1]
+
+    result = torch.tensor(np.stack(keep), device=device)
+    # Needed for data-dependent condition :(
+    assert result.size(0) >= 2
+    return result
+
+@numpy_nms.register_fake
+def _(boxes, scores, iou_threshold):
+    assert boxes.device == scores.device
+    N = boxes.shape[0]
+    assert boxes.shape == (N, 4)
+    assert scores.shape == (N,)
+
+    ctx = torch._custom_op.impl.get_ctx()
+    i0 = ctx.create_unbacked_symint()
+    result = boxes.new_empty([i0], dtype=torch.int64)
+    return result
+
+def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
+    raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
+
+numpy_nms.register_vmap(numpy_nms_vmap)
+
+def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
+    N = 64
+    xs = make_arg([N], low=0, high=28)
+    dx = make_arg([N], low=0, high=4)
+    ys = make_arg([N], low=0, high=28)
+    dy = make_arg([N], low=0, high=4)
+    boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad)
+    scores = make_arg([N], low=0, high=1, requires_grad=requires_grad)
+    iou_threshold = make_arg([], low=0, high=1).item()
+
+    yield SampleInput(boxes, args=(scores, iou_threshold))
+
+# OpInfo entries for the custom operators defined in this module. Every entry
+# sets ``supports_out=False`` and uses the same dtype list. Some operators are
+# referenced through the ``_opoverload`` attribute of the object returned by
+# ``torch.library.custom_op``, others through ``torch.ops._torch_testing``.
+custom_op_db = [
+    OpInfo(
+        'NumpyCubeCustomOp',
+        op=numpy_cube._opoverload,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyMulCustomOp',
+        op=numpy_mul._opoverload,
+        sample_inputs_func=sample_inputs_numpy_mul,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyMulScalarCustomOp',
+        op=numpy_mul_scalar._opoverload,
+        sample_inputs_func=sample_inputs_numpy_mul_scalar,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySortCustomOp',
+        op=numpy_sort._opoverload,
+        sample_inputs_func=sample_inputs_numpy_sort,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyTakeCustomOp',
+        op=numpy_take._opoverload,
+        sample_inputs_func=sample_inputs_numpy_take,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    # The two data-dependent operators (nonzero, nms) are listed without
+    # autograd support.
+    OpInfo(
+        'NumpyNonzeroCustomOp',
+        op=numpy_nonzero._opoverload,
+        sample_inputs_func=sample_inputs_numpy_nonzero,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyNMSCustomOp',
+        op=torch.ops._torch_testing.numpy_nms,
+        sample_inputs_func=sample_inputs_numpy_nms,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyViewCopyCustomOp',
+        op=torch.ops._torch_testing.numpy_view_copy,
+        sample_inputs_func=sample_inputs_numpy_view_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        supports_out=False,
+    ),
+    # The cat/split entries below turn off the batched-gradient checks.
+    OpInfo(
+        'NumpyCatCustomOp',
+        op=torch.ops._torch_testing.numpy_cat,
+        sample_inputs_func=sample_inputs_numpy_cat,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySplitCopyCustomOp',
+        op=torch.ops._torch_testing.numpy_split_copy,
+        sample_inputs_func=sample_inputs_numpy_split_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+    # Reuses the sample inputs of numpy_split_copy; the gradcheck wrapper
+    # keeps only the first output (the tensor list) and drops the integer.
+    OpInfo(
+        'NumpySplitCopyWithIntCustomOp',
+        op=torch.ops._torch_testing.numpy_split_copy_with_int,
+        sample_inputs_func=sample_inputs_numpy_split_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+]
+
+
+# ==============================================================
+# some mechanical test cases
+# ==============================================================
+
+# Library fragment through which the ``source0``..``source3`` schemas below
+# are defined in the ``_torch_testing`` namespace.
+lib = torch.library.Library("_torch_testing", "FRAGMENT")  # noqa: TOR901
+
+# source0: schema defined via ``lib.define``; fake implementation registered
+# with the decorator form of ``torch.library.register_fake``.
+lib.define("source0(Tensor x) -> Tensor")
+
+@torch.library.register_fake("_torch_testing::source0", lib=lib)
+def _(x):
+    return x.clone()
+
+# source1: schema defined via ``lib.define``; fake implementation is a named
+# function passed directly to ``torch.library.register_fake``.
+lib.define("source1(Tensor x) -> Tensor")
+
+def source1_fake(x):
+    return x.clone()
+
+torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib)
+
+# source2: same registration form as source0 (``lib.define`` plus the
+# decorator form of ``torch.library.register_fake``).
+lib.define("source2(Tensor x) -> Tensor")
+
+@torch.library.register_fake("_torch_testing::source2", lib=lib)
+def _(x):
+    return x.clone()
+
+# source3: same registration form as source1 (``lib.define`` plus a direct
+# call to ``torch.library.register_fake`` with a named function).
+lib.define("source3(Tensor x) -> Tensor")
+
+def source3_fake(x):
+    return x.clone()
+
+torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib)
+
+
+# source4: operator created with ``torch.library.custom_op``; fake
+# implementation attached with the ``register_fake`` decorator of the
+# resulting object.
+@torch.library.custom_op("_torch_testing::source4", mutates_args=())
+def source4(x: Tensor) -> Tensor:
+    return x.clone()
+
+@source4.register_fake
+def _(x):
+    return x.clone()
+
+# source5: operator created with ``torch.library.custom_op``; fake
+# implementation is a named function passed to ``source5.register_fake``
+# as a plain call.
+@torch.library.custom_op("_torch_testing::source5", mutates_args=())
+def source5(x: Tensor) -> Tensor:
+    return x.clone()
+
+def source5_fake(x):
+    return x.clone()
+
+source5.register_fake(source5_fake)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..950636a3baaf66f9a6a61f2273cb5c94ffae36ee
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3db58f22468017957222f90d26454f9e9e8c7d18
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ada28d97ea08c7530d612c4b77e52e5774e21f76
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93c9fb746eef20915b1626f0c7d3d42301daa41a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eddb9a3df59f20d965da9179932117ddff2a7299
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ecba9c0adfb144b68cea8650fd3a4586d2915a2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..239856debdcfa1b78970ad9a1624ba460e4f5340
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bf0b183f35f9f3e8597d236b0c0e0d9484b26cf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2cc36a736df99b76939ae94ebc47a6f54a75fc9a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1a6aabbdb8f060b9400bd337e91744238271c17
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a147ebe021dcd76fa28bba3e963e9d3c043c5621
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acc7005c6b9e3d64d1ca50714839b0732d41b5a5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py
@@ -0,0 +1 @@
+# mypy: allow-untyped-defs
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f148aec71ed7567350433b92b5b7302a930eef8a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7bfa3b8d7bc1332fd3c3ee47a6d806af944b647
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..60c744ac1a84cfb9220221a583a4849b6039c353
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
@@ -0,0 +1,103 @@
+# mypy: allow-untyped-defs
+
+import sys
+from functools import partial, wraps
+
+import torch
+import torch.distributed as dist
+from torch.distributed import rpc
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    TEST_SKIPS,
+    tp_transports,
+)
+
+
+# Returned as ``world_size`` by ShardedTensorTestBase.
+TEST_GPU_NUM = 4
+
+
+class ShardedTensorTestBase(MultiProcessTestCase):
+    @property
+    def world_size(self):
+        return TEST_GPU_NUM
+
+    def init_pg(self, backend="nccl"):
+        if backend not in ["nccl", "gloo", "mpi", "hccl"]:
+            raise RuntimeError(f"Backend {backend} not supported!")
+
+        dist.init_process_group(
+            backend=backend,
+            world_size=self.world_size,
+            rank=self.rank,
+            init_method=f"file://{self.file_name}",
+        )
+
+        # set device for nccl pg for collectives
+        if backend == "nccl":
+            torch.cuda.set_device(self.rank)
+
+    def init_rpc(self):
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            _transports=tp_transports()
+        )
+        rpc_backend_options.init_method = f"file://{self.file_name}"
+        for rank in range(self.world_size):
+            rpc_backend_options.set_device_map(
+                f"worker{rank}", {rank: self.rank, self.rank: rank}
+            )
+
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+    def init_comms(self, init_rpc=True, backend="nccl"):
+        if init_rpc:
+            self.init_rpc()
+        self.init_pg(backend=backend)
+
+    def destroy_comms(self, destroy_rpc=True):
+        # Wait for all ranks to reach here before starting shutdown.
+        dist.barrier()
+
+        if destroy_rpc:
+            rpc.shutdown()
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def assert_sharded_tensor_equal(self, st1, st2):
+        st1_local_shards = st1.local_shards()
+        st2_local_shards = st2.local_shards()
+        self.assertEqual(len(st1_local_shards), len(st2_local_shards))
+        for i, st1_local_shard in enumerate(st1_local_shards):
+            self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
+            self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)
+
+        self.assertEqual(st1.metadata(), st2.metadata())
+        self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
+        self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
+
+
+# wrapper to initialize comms (processgroup + rpc)
+def with_comms(func=None, init_rpc=True, backend="nccl"):
+    if func is None:
+        return partial(
+            with_comms,
+            init_rpc=init_rpc,
+            backend=backend,
+        )
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        self.init_comms(init_rpc=init_rpc, backend=backend)
+        func(self, *args, **kwargs)
+        self.destroy_comms(destroy_rpc=init_rpc)
+
+    return wrapper
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2308e8f6050dda229412a660ccee9daf28a6c041
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3f2f9207d618d4771c28a7b90f14ad3f096cf5d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e90b8252ade94b262a2a61fbb3c25a09f00a47e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83bc3a35102a051d42587352c2dcb7967510903
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py
@@ -0,0 +1,137 @@
+# mypy: allow-untyped-defs
+
+import builtins
+
+import torch
+from torch.distributed._shard.sharding_spec import (
+    ChunkShardingSpec,
+    EnumerableShardingSpec,
+    ShardMetadata,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+    get_chunked_dim_size,
+    get_split_size,
+)
+
+
+def generate_chunk_sharding_specs_for_test(sharding_dim):
+    return [
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        ),
+        # Test different ordering. (Case 1)
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+            ],
+        ),
+        # Test different ordering. (Case 2)
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:3/cuda:3",
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+            ],
+        ),
+    ]
+
+
+def generate_enumerable_sharding_specs_for_test():
+    return [
+        EnumerableShardingSpec(
+            [
+                ShardMetadata(
+                    shard_offsets=[0, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:0/cuda:0",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:1/cuda:1",
+                ),
+                ShardMetadata(
+                    shard_offsets=[0, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:2/cuda:2",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:3/cuda:3",
+                ),
+            ]
+        )
+    ]
+
+
+def generate_local_weight_sharding_params_for_test(
+    local_weight, sharded_dim, gpu_num, spec, rank
+):
+    """
+    Shard the local weight based the given spec, so we can compare against
+    the one from sharded tensor.
+
+    Args:
+        local_weight: weight matrix to be sharded.
+        sharded_dim: The dimension which we shard on.
+        gpu_num: number of ranks.
+        spec: sharding spec.
+        rank: # of cuda process.
+
+    Returns:
+        start_pos: start position of sharded weight on the given rank.
+        chunk_size: chunk size of sharded weight on the given rank.
+    """
+    sharding_dim_size = local_weight.size(sharded_dim)
+    split_size = get_split_size(sharding_dim_size, gpu_num)
+    current_offsets = 0
+    start_pos = current_offsets
+    for idx, placement in enumerate(spec.placements):
+        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+        if rank == placement.rank():
+            start_pos = current_offsets
+            break
+        current_offsets += chunk_size
+    return start_pos, chunk_size
+
+
+def clone_module_parameter(module, param_name):
+    """
+    Clone a parameter from a given existing module.
+
+    Args:
+        module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
+        param_name (str): Name of the parameter of ``module`` that needs to be cloned.
+
+    Returns: cloned tensor as :class:`torch.nn.Parameter`.
+    """
+    tensor = getattr(module, param_name)
+    return torch.nn.Parameter(tensor.detach().clone())
+
+
+def gen_binary_op_func(python_op, inplace=False):
+    src_lines = ["def f(lhs, rhs):"]
+    if "torch" in python_op:
+        src_lines.append(f"  return {python_op}(lhs, rhs)\n")
+    elif inplace:
+        src_lines.append(f"  lhs {python_op}= rhs\n  return lhs\n")
+    else:
+        src_lines.append(f"  return lhs {python_op} rhs\n")
+
+    code_str = "\n".join(src_lines)
+    g = {"torch": torch}
+    builtins.exec(code_str, g)
+    return g["f"]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe82a8dc43f8f876cb4c8d0c000cda9a32d46fb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py
@@ -0,0 +1,56 @@
+# mypy: allow-untyped-defs
+
+import copy
+import random
+
+import torch
+from torch.distributed._shard import sharded_tensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+
+
+# Placement strings of the form "rank:<r>/cuda:<d>" used to build the
+# ChunkShardingSpecs in this module.
+# NOTE: ``_chunk_sharding_specs_list_for_test`` shuffles this list in place, so
+# its order changes every time that helper runs.
+PLACEMENTS = [
+    "rank:0/cuda:0",
+    "rank:1/cuda:1",
+    "rank:2/cuda:2",
+    "rank:3/cuda:3",
+]
+
+# Equals the number of entries in PLACEMENTS.
+DEFAULT_GPU_NUM = 4
+
+
+def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
+    """
+    Build one ``ChunkShardingSpec`` per entry of ``sharding_dims``.
+
+    For the i-th dimension the module-level ``PLACEMENTS`` list is shuffled in
+    place with ``random.Random(seed + i)`` and a deep copy of the resulting order
+    is stored in the spec. Because the shuffle mutates the shared list, each
+    shuffle starts from the order left behind by the previous one (including
+    previous calls of this function).
+
+    Args:
+        sharding_dims: sequence of dimensions to shard along.
+        seed (int): base seed; the i-th spec uses ``seed + i``.
+
+    Returns: list of ``ChunkShardingSpec`` in the order of ``sharding_dims``.
+    """
+    specs = []
+    for offset, dim in enumerate(sharding_dims):
+        rng = random.Random(seed + offset)
+        rng.shuffle(PLACEMENTS)
+        shuffled_snapshot = copy.deepcopy(PLACEMENTS)
+        specs.append(ChunkShardingSpec(dim=dim, placements=shuffled_snapshot))
+    return specs
+
+
+class MyShardedModel2(torch.nn.Module):
+    """
+    Test module holding an optional sharded tensor and one dense parameter.
+
+    ``sharded_tensor2`` is created with ``sharded_tensor.rand(spec, 10, 20, ...)``
+    when ``spec`` is given and is ``None`` otherwise; ``random_tensor2`` is always
+    a ``(2, 2)`` ``torch.nn.Parameter``.
+    """
+
+    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
+        super().__init__()
+        if spec is None:
+            self.sharded_tensor2 = None
+        else:
+            self.sharded_tensor2 = sharded_tensor.rand(
+                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
+            )
+        dense_init = torch.rand(2, 2)
+        self.random_tensor2 = torch.nn.Parameter(dense_init)
+
+
+class MyShardedModel1(torch.nn.Module):
+    """
+    Test module with an optional sharded tensor, a dense parameter and a submodule.
+
+    ``sharded_tensor1`` is created with ``sharded_tensor.rand(spec, 10, 20, ...)``
+    when ``spec`` is given and is ``None`` otherwise; ``random_tensor1`` is a
+    ``(2, 2)`` ``torch.nn.Parameter``; ``submodule`` is a ``MyShardedModel2`` built
+    from the same ``spec``, ``group`` and ``init_rrefs``.
+    """
+
+    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
+        super().__init__()
+        if spec is None:
+            self.sharded_tensor1 = None
+        else:
+            self.sharded_tensor1 = sharded_tensor.rand(
+                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
+            )
+        dense_init = torch.rand(2, 2)
+        self.random_tensor1 = torch.nn.Parameter(dense_init)
+        self.submodule = MyShardedModel2(spec, group, init_rrefs)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9390da489851872ec1d0715a0b3e46275e5752b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py
@@ -0,0 +1,41 @@
+# mypy: allow-untyped-defs
+
+import torch
+import torch.nn as nn
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+
+class SimpleMegatronLM(nn.Module):
+    """
+    Two-layer MLP (``fc1`` -> GELU -> ``fc2``) used by sharding tests.
+
+    Args:
+        linear_size: pair of argument sequences; ``linear_size[0]`` and
+            ``linear_size[1]`` are unpacked into the ``nn.Linear`` constructors
+            of ``fc1`` and ``fc2`` respectively.
+        rank: when not ``None``, both linear layers are moved with ``.cuda(rank)``.
+        dtype: dtype passed to both linear layers.
+    """
+
+    def __init__(self, linear_size, rank=None, dtype=torch.float32):
+        super().__init__()
+        fc1_args, fc2_args = linear_size[0], linear_size[1]
+        self.fc1 = nn.Linear(*fc1_args, dtype=dtype)
+        self.gelu = nn.GELU()
+        self.fc2 = nn.Linear(*fc2_args, dtype=dtype)
+        if rank is None:
+            return
+        self.fc1.cuda(rank)
+        self.fc2.cuda(rank)
+
+    def forward(self, inp):
+        hidden = self.gelu(self.fc1(inp))
+        return self.fc2(hidden)
+
+    def get_weights(self):
+        """Return ``(fc1 weight, fc2 weight)``, using the local tensor of any ShardedTensor."""
+        weights = []
+        for layer in (self.fc1, self.fc2):
+            if isinstance(layer.weight, ShardedTensor):
+                weights.append(layer.weight.local_tensor())
+            else:
+                weights.append(layer.weight)
+        return tuple(weights)
+
+    def get_biases(self):
+        """Return ``(fc1.bias, fc2.bias)``."""
+        bias1, bias2 = self.fc1.bias, self.fc2.bias
+        return (bias1, bias2)
+
+    def get_weight_grads(self):
+        """Return the ``.grad`` of both weights as a pair."""
+        grad1, grad2 = self.fc1.weight.grad, self.fc2.weight.grad
+        return (grad1, grad2)
+
+    def get_bias_grads(self):
+        """Return the ``.grad`` of both biases as a pair."""
+        grad1, grad2 = self.fc1.bias.grad, self.fc2.bias.grad
+        return (grad1, grad2)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5434c73a46ca652d69bd6aea934353ad284b2ff6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b56ec429b99ef48a31db27fafdfea2711c891610
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c749ca2d541659cb0b9ef67242b48aa235831cb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -0,0 +1,1019 @@
+# mypy: allow-untyped-defs
+
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import contextlib
+import copy
+import functools
+import itertools
+import sys
+import types
+from collections.abc import Callable, Iterator, Sequence
+from dataclasses import dataclass
+from functools import partial, wraps
+from typing import Any, cast, Optional, TypeVar, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed._local_tensor import (
+    LocalIntNode,
+    LocalTensor,
+    LocalTensorMode,
+    maybe_disable_local_tensor_mode,
+    maybe_run_for_local_tensor,
+)
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
+from torch.distributed.tensor._redistribute import redistribute_local_tensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.testing._internal.common_distributed import (
+    ACCELERATOR_DIST_BACKENDS,
+    MultiProcContinuousTest,
+    MultiProcessTestCase,
+    MultiThreadedTestCase,
+    run_subtests,
+    skip_if_lt_x_gpu,
+    TEST_SKIPS,
+)
+from torch.testing._internal.common_utils import (
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_PRIVATEUSE1,
+    TEST_XPU,
+)
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+
+
+# Declared here but only assigned in the accelerator branch below; on hosts
+# without an accelerator the name stays unbound. The reads visible in this part
+# of the module check the TEST_* accelerator flags first, so they short-circuit
+# before touching it.
+DEVICE_COUNT: int
+
+# Pick the device type and default process-group backend for this host.
+if TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1:
+    DEVICE_TYPE = torch.accelerator.current_accelerator().type
+    DEVICE_COUNT = torch.accelerator.device_count()
+    PG_BACKEND = dist.Backend.default_device_backend_map[DEVICE_TYPE]
+else:
+    DEVICE_TYPE = "cpu"
+    PG_BACKEND = "gloo"
+
+# Default world size used by the test bases in this module.
+NUM_DEVICES = 4
+
+# We use this as a proxy for "multiple GPUs exist"
+if (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1) and DEVICE_COUNT > 1:
+    # when we actually have multiple GPUs, relax the requirement to smaller counts.
+    NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)
+
+# Generic type variable used to type decorators that return what they receive.
+T = TypeVar("T")
+
+
+# simple RMSNorm layer for testing
+class RMSNormPython(torch.nn.Module):
+    """
+    Plain-PyTorch RMS normalization over the last dimension.
+
+    Computes ``x * rsqrt(mean(x**2, dim=-1, keepdim=True) + eps) * weight`` where
+    ``weight`` is a learnable vector of length ``dim`` initialised to ones.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        # mean of squares along the last dimension, kept for broadcasting
+        mean_square = x.pow(2).mean(-1, keepdim=True)
+        inv_rms = torch.rsqrt(mean_square + self.eps)
+        return x * inv_rms
+
+    def forward(self, x):
+        return self._norm(x) * self.weight
+
+
+class MLPModule(nn.Module):
+    """
+    Two-layer MLP: ``Linear(10, 16)`` -> ReLU -> ``Linear(16, 10)``.
+
+    The constructor calls ``torch.manual_seed(5)`` before creating the layers, so
+    it resets the global RNG as a side effect.
+
+    Args:
+        device: device passed to both ``nn.Linear`` layers.
+        bias (bool): whether the linear layers have a bias term.
+    """
+
+    def __init__(self, device, bias: bool = True):
+        super().__init__()
+        torch.manual_seed(5)
+        linear_kwargs = {"bias": bias, "device": device}
+        self.net1 = nn.Linear(10, 16, **linear_kwargs)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(16, 10, **linear_kwargs)
+
+    def forward(self, x):
+        hidden = self.relu(self.net1(x))
+        return self.net2(hidden)
+
+    def reset_parameters(self):
+        """Re-initialise both linear layers, ``net1`` first."""
+        for layer in (self.net1, self.net2):
+            layer.reset_parameters()
+
+
+class MLPStacked(nn.Module):
+    """
+    Sequential stack of ``n_layers`` ``MLPModule`` instances.
+
+    Args:
+        device: device passed to every ``MLPModule``.
+        n_layers (int): number of stacked modules.
+    """
+
+    def __init__(self, device, n_layers: int = 2):
+        super().__init__()
+        stacked = [MLPModule(device) for _ in range(n_layers)]
+        self.layers = nn.ModuleList(stacked)
+
+    def forward(self, x):
+        # feed the output of each layer into the next one, in order
+        out = x
+        for block in self.layers:
+            out = block(out)
+        return out
+
+
+# Hyper-parameters consumed by the toy ``Transformer`` and its submodules below.
+@dataclass
+class ModelArgs:
+    # number of TransformerBlock layers appended to Transformer.layers
+    n_layers: int = 2
+    # rows of the token embedding table and output features of the final Linear
+    vocab_size: int = 8
+    # rows of the positional embedding table; Transformer.forward asserts that
+    # the input sequence length does not exceed it
+    max_seq_len: int = 16
+    # model width; Attention asserts it is divisible by n_heads
+    dim: int = 16
+    # number of attention heads (head_dim = dim // n_heads)
+    n_heads: int = 4
+    # dropout probability used by the nn.Dropout layers and, in training mode,
+    # passed to scaled_dot_product_attention
+    dropout_p: float = 0.1
+    # forwarded by Attention as the sixth positional argument of
+    # F.scaled_dot_product_attention (presumably ``is_causal`` -- confirm
+    # against the PyTorch signature)
+    use_attn_mask: bool = True
+    # when True, Transformer.output.weight is set to tok_embeddings.weight
+    weight_tying: bool = True
+    # when True, Transformer.forward runs each layer through
+    # torch.utils.checkpoint.checkpoint
+    checkpoint_activations: bool = False
+
+
+class Attention(nn.Module):
+    """
+    Multi-head self-attention built on ``F.scaled_dot_product_attention``.
+
+    The input ``x`` is unpacked as ``(bsz, seq_len, _)``; it is projected by the
+    bias-free ``wq``/``wk``/``wv`` layers, split into ``n_heads`` heads of size
+    ``head_dim``, attended, merged back and passed through ``wo`` followed by
+    residual dropout.
+    """
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.dim % args.n_heads == 0
+        self.head_dim = args.dim // args.n_heads
+        self.n_heads = args.n_heads
+        self.dropout_p = args.dropout_p
+        self.resid_dropout = nn.Dropout(args.dropout_p)
+        self.use_attn_mask = args.use_attn_mask
+
+        # square (dim -> dim) projections without bias
+        width = args.dim
+        self.wq = nn.Linear(width, width, bias=False)
+        self.wk = nn.Linear(width, width, bias=False)
+        self.wv = nn.Linear(width, width, bias=False)
+        self.wo = nn.Linear(width, width, bias=False)
+
+    def forward(self, x):
+        bsz, seq_len, _ = x.size()
+        q_proj, k_proj, v_proj = self.wq(x), self.wk(x), self.wv(x)
+
+        # (bsz, seq_len, dim) -> (bsz, n_heads, seq_len, head_dim)
+        queries = q_proj.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        keys = k_proj.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        values = v_proj.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+        # attention dropout is only active in training mode
+        attn_dropout = self.dropout_p if self.training else 0
+        attended = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            None,
+            attn_dropout,
+            self.use_attn_mask,
+        )
+
+        # merge heads back: (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, -1)
+        merged = attended.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        projected = self.wo(merged)
+        return self.resid_dropout(projected)
+
+
+class FeedForward(nn.Module):
+    """
+    Position-wise MLP: ``Linear(dim, hidden_dim)`` -> GELU -> ``Linear(hidden_dim, dim)``
+    followed by dropout with probability ``dropout_p``.
+    """
+
+    def __init__(self, dim, hidden_dim, dropout_p):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim)
+        self.gelu = nn.GELU()
+        self.w2 = nn.Linear(hidden_dim, dim)
+        self.resid_dropout = nn.Dropout(dropout_p)
+
+    def forward(self, x):
+        hidden = self.gelu(self.w1(x))
+        projected = self.w2(hidden)
+        return self.resid_dropout(projected)
+
+
+class TransformerBlock(nn.Module):
+    """
+    Pre-norm transformer block: ``x + attention(norm(x))`` followed by
+    ``h + feed_forward(norm(h))``, using ``nn.LayerNorm`` for both norms and a
+    feed-forward hidden width of ``4 * args.dim``.
+    """
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention_norm = nn.LayerNorm(args.dim)
+        self.attention = Attention(args)
+        self.ffn_norm = nn.LayerNorm(args.dim)
+        hidden_width = 4 * args.dim
+        self.feed_forward = FeedForward(
+            args.dim, hidden_dim=hidden_width, dropout_p=args.dropout_p
+        )
+
+    def forward(self, x):
+        # residual connection around attention
+        attn_out = self.attention(self.attention_norm(x))
+        h = x + attn_out
+        # residual connection around the feed-forward network
+        ffn_out = self.feed_forward(self.ffn_norm(h))
+        return h + ffn_out
+
+
+# A toy transformer model, partly inspired by the nanoGPT model:
+# https://github.com/karpathy/nanoGPT.
+class Transformer(nn.Module):
+    """
+    Small decoder-style transformer used to exercise tensor/sequence parallelism.
+
+    It consists of token and positional embeddings, dropout, ``args.n_layers``
+    ``TransformerBlock`` layers, a final ``nn.LayerNorm`` and a bias-free output
+    projection to ``args.vocab_size``. With ``args.weight_tying`` the output
+    projection shares its weight with the token embedding.
+    """
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.vocab_size is not None
+        assert args.max_seq_len is not None
+        self.model_args = args
+        self.max_seq_len = args.max_seq_len
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
+        self.dropout = nn.Dropout(args.dropout_p)
+        self.layers = nn.ModuleList()
+        for _ in range(args.n_layers):
+            self.layers.append(TransformerBlock(args))
+        self.norm = nn.LayerNorm(args.dim)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+        if args.weight_tying:
+            # share one parameter between the embedding table and the output head
+            self.output.weight = self.tok_embeddings.weight
+        self.checkpoint_activations = args.checkpoint_activations
+
+    def forward(self, tokens):
+        """
+        Run the model on ``tokens``, which is unpacked as ``(bsz, seq_len)``.
+
+        Returns the output projection cast to float32 via ``.float()``.
+        Asserts that ``seq_len`` does not exceed ``max_seq_len``.
+        """
+        _bsz, seq_len = tokens.size()
+        assert seq_len <= self.max_seq_len
+        h = self.tok_embeddings(tokens)
+        pos = torch.arange(0, seq_len, device=tokens.device)
+        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
+        h = h + p
+        h = self.dropout(h)
+        for layer in self.layers:
+            if self.checkpoint_activations:
+                # recompute the layer's activations in backward instead of storing them
+                h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
+            else:
+                h = layer(h)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
+
+    @staticmethod
+    def parallelize(
+        module: "Transformer",
+        device_mesh: DeviceMesh,
+        use_seq_parallel: bool,
+        local_output_for_attn: bool = False,
+    ) -> nn.Module:
+        """
+        Apply a tensor-parallel (optionally sequence-parallel) plan to ``module``.
+
+        Args:
+            module: the ``Transformer`` to parallelize (asserted).
+            device_mesh: mesh passed to every ``parallelize_module`` call.
+            use_seq_parallel: when True, embeddings emit sharded outputs, the norm
+                layers use ``SequenceParallel`` and the per-layer linears take or
+                produce ``Shard(1)`` layouts.
+            local_output_for_attn: forwarded as ``use_local_output`` to the
+                ``ColwiseParallel`` styles of ``wq``/``wk``/``wv``; when True,
+                each layer's ``attention.n_heads`` is also set to
+                ``n_heads // device_mesh.size()``.
+
+        Returns: the module returned by ``parallelize_module`` for the root plan.
+        """
+        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+        # Parallelize the root submodules.
+        if use_seq_parallel:
+            root_plan = {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(1)
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(0)
+                ),
+                "norm": SequenceParallel(),
+            }
+        else:
+            root_plan = {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Replicate()
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Replicate()
+                ),
+            }
+
+        module_tp = parallelize_module(module, device_mesh, root_plan)
+        # Parallelize the attention and feed forward submodules.
+        for layer in module_tp.layers:
+            layer_parallelize_plan = {}
+            if use_seq_parallel:
+                layer_parallelize_plan["attention"] = PrepareModuleInput(
+                    input_layouts=Shard(1),
+                    desired_input_layouts=Replicate(),
+                )
+                # shard the norm layers (nn.LayerNorm instances in TransformerBlock)
+                layer_parallelize_plan["attention_norm"] = SequenceParallel()
+                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
+            layer_parallelize_plan["attention.wq"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wk"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wv"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wo"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            layer_parallelize_plan["feed_forward.w1"] = (
+                ColwiseParallel(input_layouts=Shard(1))
+                if use_seq_parallel
+                else ColwiseParallel()
+            )
+            layer_parallelize_plan["feed_forward.w2"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            parallelize_module(layer, device_mesh, layer_parallelize_plan)
+
+        # Parallelize the output submodule. If weight tying is enabled, we need to
+        # make sure output.weight is sharded consistently as tok_embeddings.weight,
+        # at the cost of the all_reduce operation using RowwiseParallel.
+        output_parallelize_plan = (
+            ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+            )
+            if use_seq_parallel
+            else ColwiseParallel(output_layouts=Replicate())
+        )
+        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+        if local_output_for_attn:
+            # Attention.forward reshapes with n_heads, so give it the per-rank
+            # head count when wq/wk/wv hand back local (unwrapped) outputs.
+            for layer in module_tp.layers:
+                layer.attention.n_heads = (
+                    module_tp.model_args.n_heads // device_mesh.size()
+                )
+
+        # Manually set output.weight so that parameters and gradients are shared.
+        if module_tp.model_args.weight_tying:
+            module_tp.output.weight = module_tp.tok_embeddings.weight
+
+        return module_tp
+
+
+def skip_unless_torch_gpu(method: T) -> T:
+    """
+    Test decorator that wraps ``method`` with ``skip_if_lt_x_gpu(NUM_DEVICES)``.
+
+    >>> # xdoctest: +SKIP
+    >>> @skip_unless_torch_gpu
+    >>> def test_some_method(self) -> None:
+    >>>   ...
+    """
+    # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
+    gpu_guard = skip_if_lt_x_gpu(NUM_DEVICES)
+    guarded = gpu_guard(method)
+    return cast(T, guarded)
+
+
+class DTensorContinuousTestBase(MultiProcContinuousTest):
+    """
+    ``MultiProcContinuousTest`` base that picks the device type and the matching
+    process-group backend for DTensor tests.
+    """
+
+    @classmethod
+    def device_type(cls) -> str:
+        """Return the accelerator type, or ``"cpu"`` when there are too few devices."""
+        # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < cls.world_size
+        ):
+            return "cpu"
+        else:
+            return DEVICE_TYPE
+
+    @classmethod
+    def backend_str(cls) -> str:
+        """Return the default backend for the device chosen by ``device_type()``."""
+        # Derive the backend from the device actually selected (which may have
+        # fallen back to "cpu"), not from the global DEVICE_TYPE, so the device
+        # and backend always agree -- same rule as DTensorTestBase.backend.
+        backend = dist.get_default_backend_for_device(cls.device_type())
+        return backend
+
+
+class DTensorTestBase(MultiProcessTestCase):
+    """
+    ``MultiProcessTestCase`` base for DTensor tests.
+
+    Provides device/backend selection, process-group setup and teardown
+    (``init_pg`` / ``destroy_pg``) and helpers that compare an op's result on
+    DTensor inputs with its result on the corresponding full tensors.
+    """
+
+    @property
+    def is_local_tensor_enabled(self) -> bool:
+        return False
+
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    @property
+    def device_type(self) -> str:
+        # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < self.world_size
+        ):
+            return "cpu"
+        else:
+            return DEVICE_TYPE
+
+    @property
+    def backend(self) -> str:
+        # default backend for whichever device type was selected above
+        backend = dist.get_default_backend_for_device(self.device_type)
+        return backend
+
+    def init_manual_seed_for_rank(self) -> None:
+        # seed the global RNG with this process's rank
+        torch.manual_seed(self.rank)
+
+    def build_device_mesh(self) -> DeviceMesh:
+        # 1-D mesh spanning all ranks on the selected device type
+        return init_device_mesh(self.device_type, (self.world_size,))
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        """
+        Initialise the default process group for this rank.
+
+        Args:
+            eager_init: when truthy and the backend contains "nccl" or "xccl",
+                a ``device_id`` is passed to ``init_process_group``.
+            backend: backend string; defaults to ``self.backend``.
+
+        Raises:
+            RuntimeError: if ``backend`` is not one of the supported strings.
+
+        Exits the process with the matching ``TEST_SKIPS`` code when an
+        accelerator backend is requested but fewer than ``world_size`` devices
+        are available.
+        """
+        if backend is None:
+            backend = self.backend
+
+        requires_gpu = any(
+            gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS
+        )
+        if requires_gpu and torch.accelerator.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        curr_backend = dist.get_default_backend_for_device(self.device_type)
+
+        if backend not in [
+            "nccl",
+            "gloo",
+            "mpi",
+            f"cpu:gloo,{self.device_type}:{curr_backend}",
+            "hccl",
+            "xccl",
+            "fake",
+            "cpu:gloo,xpu:xccl",
+        ]:
+            raise RuntimeError(f"Backend {backend} not supported!")
+
+        device_id = None
+        if "nccl" in backend or "xccl" in backend:
+            # set device for nccl pg for collectives
+            # TODO: if users want to enable testing across hosts, we may need
+            # to change this part.
+            torch.accelerator.set_device_index(self.rank)
+            # we only need to set device_id for nccl backend with eager init
+            device_id = (
+                torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
+            )
+
+        # For nccl backend, bind the device to the process if device_id is not None
+        # so the nccl communicator is immediately formed and we can use `ncclCommSplit`
+        # to form subgroups and avoid unnecessary overhead.
+        dist.init_process_group(
+            backend=backend,
+            world_size=self.world_size,
+            rank=self.rank,  # pyre-ignore[16]
+            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
+            device_id=device_id,
+        )
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        """
+        Synchronise all ranks with a barrier, then destroy the process group.
+
+        Args:
+            device_id: device index used for the barrier on non-CPU device
+                types; defaults to the current CUDA device when the device type
+                is "cuda", otherwise to ``self.rank``.
+        """
+        # Wait for all ranks to reach here before starting shutdown.
+        # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
+        # dist.all_reduce(torch.zeros((1,), device="cuda" if TEST_CUDA else "cpu"))
+        # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
+        #  test_dtensor.py  -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
+        if device_id is None:
+            device_id = (
+                torch.cuda.current_device() if self.device_type == "cuda" else self.rank
+            )
+
+        if self.device_type == "cpu":
+            # NOTE: when `device_id` is not None, barrier() will choose the accelerator
+            # of the highest priority, which means if the test specifies to use CPU for
+            # testing while CUDA is available on the host, the barrier() will use CUDA.
+            # To avoid this and better respect `self.device_type`, we add this branch to
+            # enforce barrier() to use CPU when `self.device_type` is CPU and other
+            # accelerator is also available.
+            dist.barrier()
+        else:
+            dist.barrier(device_ids=[device_id])
+
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None:
+        """
+        This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``.
+        Unlike _test_op where the DTensor sharding is generated by DTensorConverter,
+        this function takes in DTensor object directly as argument and test the equality
+        of calling op on full_tensor() and DTensor.
+
+        Every flattened output of ``op_call(*args, **kwargs)`` has ``full_tensor()``
+        called on it before the comparison.
+        """
+        # call full_tensor() on DTensor args/kwargs
+        # (positional DTensors are additionally detached and cloned; kwargs are
+        # only unwrapped at the top level of the dict)
+        args_flattened, args_spec = tree_flatten(args)
+        full_tensor_args_flattened = tuple(
+            arg.full_tensor().detach().clone() if isinstance(arg, DTensor) else arg
+            for arg in args_flattened
+        )
+        full_tensor_args = tree_unflatten(full_tensor_args_flattened, args_spec)
+        full_tensor_kwargs = {
+            k: v.full_tensor() if isinstance(v, DTensor) else v
+            for k, v in kwargs.items()
+        }
+
+        out_flattened, _ = tree_flatten(
+            op_call(*full_tensor_args, **full_tensor_kwargs)
+        )
+        d_out_flattened, _ = tree_flatten(op_call(*args, **kwargs))
+        d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened]
+        self.assertEqual(out_flattened, d_out_full_tensor_flattened)
+
+    # pyre-ignore[2]:
+    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
+        """
+        Run ``op_call`` on plain inputs, then on every sharding combination
+        produced by ``DTensorConverter``, asserting that each distributed result's
+        ``full_tensor()`` equals the plain result and that every conversion
+        succeeded.
+        """
+        out = op_call(*args, **kwargs)
+        dtc = DTensorConverter(mesh, args, kwargs)
+        for d_args, d_kwargs in dtc:
+            # pyre can't find assertTrue anymore?
+            self.assertEqual(dtc.successful(), True)
+            d_out = op_call(*d_args, **d_kwargs)
+            self.assertEqual(d_out.full_tensor(), out)
+
+    def run_subtests(self, *args, **kwargs):
+        # delegate to the module-level run_subtests helper, passing this test case
+        return run_subtests(self, *args, **kwargs)
+
+
+# Any callable regardless of its parameters. ``Callable[..., R]`` is the form
+# for "arbitrary arguments"; ``Callable[[...], R]`` is not a valid parameter list.
+TestFunc = Callable[..., object]
+
+
+# wrapper to initialize comms (processgroup)
+def with_comms(
+    eager_init: Union[TestFunc, bool] = False, backend: Optional[str] = None
+) -> TestFunc:
+    """
+    Decorator that sets up and tears down the process group around a test method.
+
+    Usable both bare (``@with_comms``) and with arguments
+    (``@with_comms(eager_init=True, backend="gloo")``).
+
+    Args:
+        eager_init: either the test function itself (bare usage) or a flag that
+            is forwarded to ``self.init_pg``.
+        backend: backend string forwarded to ``self.init_pg``.
+
+    The wrapped test calls ``self.init_pg(eager_init, backend)``, runs the test
+    and then ``self.destroy_pg()``. If the test raises, the process group is
+    destroyed directly with ``dist.destroy_process_group()`` and the exception is
+    re-raised. Test cases without an ``init_pg`` attribute run the test unchanged.
+    """
+
+    def decorator(func, eager_init: bool = False, backend: Optional[str] = None):
+        @wraps(func)  # pyre-ignore[6]
+        def wrapper(
+            self,
+            *args: tuple[object],
+            **kwargs: dict[str, Any],  # type: ignore[misc]
+        ) -> None:
+            # just passthrough if harness doesn't
+            # support init_pg e.g., DTensorOpTestBase
+            if not hasattr(self, "init_pg"):
+                func(self, *args, **kwargs)
+                return
+
+            self.init_pg(eager_init, backend)
+
+            try:
+                func(self, *args, **kwargs)  # type: ignore[misc]
+            except Exception:
+                # tear down without destroy_pg() so a failing test does not
+                # first go through its barrier
+                dist.destroy_process_group()
+                raise
+
+            self.destroy_pg()
+
+        return wrapper
+
+    if callable(eager_init):
+        # bare ``@with_comms``: ``eager_init`` is actually the test function
+        return decorator(func=eager_init)
+    return partial(decorator, eager_init=eager_init, backend=backend)
+
+
+class DTensorOpTestBase(MultiThreadedTestCase):
+    """
+    ``MultiThreadedTestCase`` base for DTensor op tests.
+
+    Uses ``NUM_DEVICES`` ranks on the module-level ``DEVICE_TYPE`` and spawns the
+    rank threads in ``setUp``. It defines no ``init_pg``, so ``with_comms`` runs
+    tests on it without touching the process group.
+    """
+
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    @property
+    def device_type(self) -> str:
+        return DEVICE_TYPE
+
+    def build_device_mesh(self):
+        # 1-D mesh spanning all ranks on this test's device type
+        return init_device_mesh(self.device_type, (self.world_size,))
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_threads()
+
+
+# This is a class for converting args/kwargs of an op into distributed args/kwargs
+class DTensorConverter:
+    """
+    Iterator over DTensor-ified versions of an op's ``args`` / ``kwargs``.
+
+    Every ``torch.Tensor`` leaf of the flattened arguments gets a list of
+    candidate placements (see ``gen_sharding_choices_for_arg``); iterating yields
+    one ``(args, kwargs)`` pair per element of the cartesian product of those
+    candidates, with tensor leaves converted by ``to_dist_tensor``.
+
+    ``hit`` counts successful conversions and ``miss`` counts tensors that were
+    passed through unconverted; ``successful()`` reports whether there was at
+    least one hit and no miss so far.
+    """
+
+    def __init__(
+        self,
+        mesh: DeviceMesh,
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> None:
+        self.hit = 0
+        self.miss = 0
+        self.mesh = mesh
+        self.args = args
+        self.kwargs = kwargs
+        flatten_args, flatten_args_spec = tree_flatten(args)
+        flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
+
+        self.flatten_args: list[object] = flatten_args
+        self.flatten_args_spec: TreeSpec = flatten_args_spec
+        self.flatten_kwargs: list[object] = flatten_kwargs
+        self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
+
+        # One list of candidate placements per tensor leaf: positional leaves
+        # first, then keyword leaves -- the same order __next__ consumes them in.
+        choices_for_args = [
+            self.gen_sharding_choices_for_arg(arg)
+            for arg in itertools.chain(self.flatten_args, self.flatten_kwargs)
+            if isinstance(arg, torch.Tensor)
+        ]
+
+        self.sharding_combs: Iterator[Sequence[Placement]] = iter(
+            itertools.product(*choices_for_args)
+        )
+
+    def successful(self) -> bool:
+        """Return True when at least one tensor was converted and none was skipped."""
+        return self.hit > 0 and self.miss == 0
+
+    def is_supported_tensor(self, t: torch.Tensor) -> bool:
+        """Return False for tensor kinds that are not converted to DTensor."""
+        # TODO: dist tensor need to support quantized and sparse
+        # tensors, quantized tensor might be relatively easy, but
+        # sparse tensor have special layouts that we need to possibly
+        # deal with, until we are clear about them, we don't officially
+        # support them.
+        unsupported_markers = (
+            t.is_sparse_csr,
+            t.is_sparse,
+            t.is_mkldnn,
+            t.is_quantized,
+            t.is_nested,
+            torch._is_functional_tensor(t),
+            t.is_neg(),
+            t.is_conj(),
+            t.device.type in ("lazy", "meta"),
+            # We need a way to test if a tensor is batched but there
+            # is no official API to do it
+            # torch._C._is_batched(t),
+        )
+        return not any(unsupported_markers)
+
+    def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
+        """
+        Return the candidate placements for ``arg``: always ``Replicate()``, plus
+        ``Shard(i)`` for every dimension ``i`` whose size is greater than 1 and
+        divisible by the mesh size. Bool tensors are only ever replicated.
+        """
+        sharding_choices: list[Placement] = [Replicate()]
+        # c10d collective does not support bool tensor
+        # for bool tensor we treat it as replicated
+        if arg.dtype == torch.bool:
+            return sharding_choices
+
+        mesh_size = self.mesh.size()
+        # only generating choices with: replicate, or sharding
+        # evenly on a dimension that could be sharded
+        sharding_choices.extend(
+            Shard(i)
+            for i, s in enumerate(arg.shape)
+            if s > 1 and s % mesh_size == 0
+        )
+        # TODO: add multi mesh choices
+        # all_choices = itertools.product(
+        #     *(self.mesh.ndim * [sharding_choices])
+        # )
+        return sharding_choices
+
+    def __iter__(self) -> "DTensorConverter":
+        return self
+
+    def __next__(self) -> tuple[tuple[object, ...], dict[str, object]]:
+        """
+        Return the next ``(args, kwargs)`` pair with tensor leaves distributed.
+
+        Raises:
+            StopIteration: when all placement combinations have been produced.
+        """
+        # An exhausted product raises StopIteration here, which ends the
+        # iteration directly; no catch-and-rethrow is needed.
+        next_sharding_choices = next(self.sharding_combs)
+
+        new_args, consumed = self._convert_leaves(
+            self.flatten_args, next_sharding_choices, 0
+        )
+        new_kwargs, _ = self._convert_leaves(
+            self.flatten_kwargs, next_sharding_choices, consumed
+        )
+
+        return (
+            tree_unflatten(new_args, self.flatten_args_spec),
+            tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
+        )
+
+    def _convert_leaves(
+        self,
+        leaves: list[object],
+        placements: Sequence[Placement],
+        start: int,
+    ) -> tuple[list[object], int]:
+        """Convert tensor leaves using ``placements[start:]``; return them and the next index."""
+        converted: list[object] = []
+        idx = start
+        for leaf in leaves:
+            if isinstance(leaf, torch.Tensor):
+                converted.append(
+                    self.to_dist_tensor(leaf, self.mesh, [placements[idx]])
+                )
+                idx += 1
+            else:
+                converted.append(leaf)
+        return converted, idx
+
+    def to_dist_tensor(
+        self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
+    ) -> torch.Tensor:
+        """
+        Distribute ``t`` over ``mesh`` with ``placements`` when it is convertible.
+
+        Only exact ``torch.Tensor``, ``nn.Parameter`` and ``LocalTensor`` instances
+        that pass ``is_supported_tensor`` are converted (counted in ``hit``);
+        0-dim tensors are replicated on every mesh dimension. Unsupported tensors
+        and other tensor-like objects are returned unchanged (counted in ``miss``).
+
+        Raises:
+            RuntimeError: if ``t`` is not tensor-like at all.
+        """
+        # exact type match on purpose: tensor subclasses go to the branch below
+        if type(t) is torch.Tensor or type(t) is nn.Parameter or type(t) is LocalTensor:
+            if not self.is_supported_tensor(t):
+                self.miss += 1
+                return t
+
+            self.hit += 1
+            if t.ndim == 0:
+                # scalar tensor by default will be replicated
+                r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim)
+            else:
+                # distribute non-scalar tensors
+                r = distribute_tensor(t, mesh, placements)
+            if isinstance(t, nn.Parameter):
+                r = nn.Parameter(  # type: ignore[assignment]
+                    r, requires_grad=r.requires_grad
+                )
+            return r
+
+        if torch.overrides.is_tensor_like(t):
+            # Blindly converting tensor subclasses to dist tensor can cause
+            # unpredictable problems, we explicitly disable this conversion
+            # for now (i.e. we don't support DTensor holding tensor subclass
+            # until there's a strong reason later).
+            self.miss += 1
+            return t
+
+        raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")
+
+
+class LocalDTensorOpTestBase(DTensorOpTestBase):
+    """
+    Variant of ``DTensorOpTestBase`` that reports ``is_local_tensor_enabled`` and
+    runs against a "fake" process group instead of spawning workers.
+    """
+
+    @property
+    def is_local_tensor_enabled(self) -> bool:
+        return True
+
+    def _handle_test_skip(self, msg: str) -> None:
+        # report skips through the regular unittest mechanism
+        self.skipTest(msg)
+
+    def _get_local_tensor_mode(self):
+        # mode covering every rank in [0, world_size)
+        return LocalTensorMode(frozenset(range(self.world_size)))
+
+    def setUp(self) -> None:
+        super().setUp()
+        # turned back on in tearDown
+        torch.autograd._enable_record_function(False)
+
+    def tearDown(self) -> None:
+        from torch.distributed.tensor import _random as random
+
+        # drop the module-level RNG tracker so it is not carried into the next test
+        random._rng_tracker = None
+        super().tearDown()
+        torch.autograd._enable_record_function(True)
+
+    @property
+    def rank(self):
+        # a SymInt backed by a LocalIntNode that maps every rank r to itself
+        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
+
+    @rank.setter
+    def rank(self, rank):
+        # assignments to ``rank`` are accepted and ignored
+        pass
+
+    def join_or_run(self, fn):
+        # Returns a method bound to this instance that calls ``fn()`` with no
+        # arguments; the ``self`` received by the wrapper is not forwarded.
+        @wraps(fn)
+        def wrapper(self):
+            fn()
+
+        return types.MethodType(wrapper, self)
+
+    def build_device_mesh(self) -> DeviceMesh:
+        # build the mesh inside maybe_disable_local_tensor_mode()
+        with maybe_disable_local_tensor_mode():
+            return super().build_device_mesh()
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        # ``eager_init`` and ``backend`` are ignored: always a "fake" group at rank 0
+        dist.init_process_group("fake", rank=0, world_size=self.world_size)
+        self._pg = dist.distributed_c10d._get_default_group()
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        # ``device_id`` is ignored; no barrier is issued before destruction
+        dist.destroy_process_group(self._pg)
+        self._pg = None
+
+    def _spawn_processes(self) -> None:
+        # intentionally a no-op
+        pass
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # call the test method directly; ``parent_pipe`` is unused
+        getattr(self, test_name)()
+
+    def init_manual_seed_for_rank(self) -> None:
+        # fixed seed 0 rather than a per-rank seed
+        torch.manual_seed(0)
+
+
+class LocalDTensorTestBase(DTensorTestBase):
+    @property
+    def is_local_tensor_enabled(self) -> bool:
+        return True
+
+    def _handle_test_skip(self, msg: str) -> None:
+        self.skipTest(msg)
+
+    def _get_local_tensor_mode(self):
+        return LocalTensorMode(frozenset(range(self.world_size)))
+
+    def setUp(self) -> None:
+        super().setUp()
+        torch.autograd._enable_record_function(False)
+
+    def tearDown(self) -> None:
+        from torch.distributed.tensor import _random as random
+
+        random._rng_tracker = None
+        super().tearDown()
+        torch.autograd._enable_record_function(True)
+
+    @property
+    def rank(self):
+        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
+
+    @rank.setter
+    def rank(self, rank):
+        pass
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            fn()
+
+        return types.MethodType(wrapper, self)
+
+    def build_device_mesh(self) -> DeviceMesh:
+        with maybe_disable_local_tensor_mode():
+            return super().build_device_mesh()
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        dist.init_process_group("fake", rank=0, world_size=self.world_size)
+        self._pg = dist.distributed_c10d._get_default_group()
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        dist.destroy_process_group(self._pg)
+        self._pg = None
+
+    def _spawn_processes(self) -> None:
+        pass
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        getattr(self, test_name)()
+
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(0)
+
+
+def make_wrapped(fn, ctxs):
+    @functools.wraps(fn)
+    def wrapped(self):
+        torch._dynamo.reset()
+        stack = contextlib.ExitStack()
+        for ctx in ctxs:
+            if callable(ctx):
+                stack.enter_context(ctx(self))
+            else:
+                stack.enter_context(ctx)
+        try:
+            out = fn(self)
+        finally:
+            stack.close()
+        return out
+
+    return wrapped
+
+
+def create_local_tensor_test_class(
+    orig_cls, skipped_tests=None, base_class=LocalDTensorTestBase
+):
+    if skipped_tests is None:
+        skipped_tests = []
+
+    dct = orig_cls.__dict__.copy()
+    for name in list(dct.keys()):
+        fn = dct[name]
+        if not callable(fn):
+            continue
+        elif name in skipped_tests:
+            dct[name] = lambda self: self.skipTest("Skipped test")
+        elif name.startswith("test_"):
+            ctxs = [
+                lambda test: test._get_local_tensor_mode(),
+            ]
+            dct[name] = make_wrapped(fn, ctxs)
+
+    cls = type(
+        orig_cls.__name__ + "WithLocalTensor",
+        (base_class,) + orig_cls.__bases__,
+        dct,
+    )
+    cls.__file__ = __file__
+    return cls
+
+
+@maybe_run_for_local_tensor
+def map_local_tensor_for_rank(tensor, rank, func):
+    return func(tensor, rank)
+
+
+@maybe_run_for_local_tensor
+def map_local_for_rank(rank, func):
+    return func(rank)
+
+
+def reduce_local_int(val, func):
+    return func(val.node._local_ints)
+
+
+def _convert_shard_order_dict_to_ShardOrder(shard_order):
+    """Convert shard_order dict to ShardOrder"""
+    return tuple(
+        ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
+        for tensor_dim, mesh_dims in shard_order.items()
+    )
+
+
+# TODO(zpcore): remove once the native redistribute supports shard_order arg
+def redistribute(
+    dtensor_input,
+    device_mesh,
+    placements,
+    shard_order,
+    use_graph_based_transform=True,
+):
+    """
+    wrapper function to support shard_order for redistribution
+    This is a simpler version of Redistribute, only considers the forward.
+    """
+    if placements is None:
+        placements = shard_order_to_placement(shard_order, device_mesh)
+    placements = tuple(placements)
+    old_spec = dtensor_input._spec
+    new_spec = copy.deepcopy(old_spec)
+    new_spec.placements = placements
+    if shard_order is not None:
+        new_spec.shard_order = shard_order
+    else:
+        new_spec.shard_order = ()
+    if old_spec == new_spec:
+        return dtensor_input
+    dtensor_input = DTensor.from_local(
+        redistribute_local_tensor(
+            dtensor_input.to_local(),
+            old_spec,
+            new_spec,
+            use_graph_based_transform=use_graph_based_transform,
+        ),
+        device_mesh,
+    )
+    dtensor_input._spec = copy.deepcopy(new_spec)
+    return dtensor_input  # returns DTensor
+
+
+# TODO(zpcore): remove once the native distribute_tensor supports
+# shard_order arg
+def patched_distribute_tensor(
+    input_tensor,
+    device_mesh,
+    placements,
+    shard_order,
+    use_graph_based_transform=True,
+):
+    """wrapper function to support shard_order for tensor distribution"""
+    if placements is None:
+        placements = shard_order_to_placement(shard_order, device_mesh)
+    placements = tuple(placements)
+    tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
+    # fix the shard order
+    return redistribute(
+        tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
+    )
+
+
+# TODO(zpcore): remove once the native redistribute supports shard_order arg
+def make_full_tensor(dtensor_input):
+    """wrapper function to support DTensor.full_tensor"""
+    return redistribute(
+        dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
+    ).to_local()
+
+
+def shard_order_to_placement(shard_order, mesh):
+    """convert shard_order to placement with only Replicate() and Shard()"""
+    placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
+    if shard_order is not None:
+        for entry in shard_order:
+            tensor_dim = entry.tensor_dim
+            mesh_dims = entry.mesh_dims
+            for mesh_dim in mesh_dims:
+                placements[mesh_dim] = Shard(tensor_dim)
+    return tuple(placements)
+
+
+def generate_shard_orders(mesh, tensor_rank):
+    # Generate all possible sharding placement of tensor with rank
+    # `tensor_rank` over mesh.
+    def _split_list(lst: list, N: int):
+        def compositions(n: int, k: int):
+            # yields lists of length k, positive ints summing to n
+            for cuts in itertools.combinations(range(1, n), k - 1):
+                # add 0 and n as sentinels, then take consecutive differences
+                yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]
+
+        length = len(lst)
+        for comp in compositions(length, N):
+            result = []
+            start = 0
+            for size in comp:
+                result.append(lst[start : start + size])
+                start += size
+            yield result
+
+    all_mesh = list(range(mesh.ndim))
+    all_device_order = list(itertools.permutations(all_mesh))
+    for device_order in all_device_order:
+        # split on device orders, and assign each device order segment to a tensor dim
+        for num_split in range(1, mesh.ndim + 1):
+            for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
+                for tensor_dims in itertools.combinations(
+                    range(tensor_rank), len(splitted_list)
+                ):
+                    shard_order = {}
+                    assert len(tensor_dims) == len(splitted_list)
+                    for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
+                        shard_order[tensor_dim] = device_order[
+                            mesh_dims[0] : mesh_dims[-1] + 1
+                        ]
+                    yield _convert_shard_order_dict_to_ShardOrder(shard_order)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6381c5a6e120bcf99b15ca851b8d94bcf10893cf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ff83771b3da6fdb21544a2178cd1de325055c5c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f14919e5382b18c41889587f7d6acb71b3f80020
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..af136fb8722d17d70767718a0cd327f71d730fda
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -0,0 +1,754 @@
+# mypy: allow-untyped-defs
+
+import enum
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.testing._internal.dist_utils as dist_utils
+from torch import nn, Tensor
+from torch._jit_internal import Future
+from torch.distributed.nn import RemoteModule
+from torch.distributed.nn.api.remote_module import (
+    _REMOTE_MODULE_PICKLED_ATTRIBUTES,
+    _RemoteModule,
+)
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# Shared parameter value: assigned to ``MyModule.param1`` and used as the
+# expected value when tests compare remotely fetched parameters.
+_PARAM_VAL = torch.nn.Parameter(torch.ones(1))
+
+
+# RPC handler for querying the device on the destination worker.
+def remote_device(module_rref):
+    for param in module_rref.local_value().parameters():
+        return param.device
+
+
+# RPC handler for querying __dict__ on the destination worker.
+def remote_module_attributes(remote_module):
+    return remote_module.__dict__
+
+
+# RPC handler for running forward on the destination worker.
+def remote_forward(remote_module, args):
+    return remote_module.forward(*args)
+
+
+# RPC handler for running forward_async on the destination worker.
+def remote_forward_async(remote_module, args):
+    # Since future cannot be pickled and sent over the RPC layer,
+    # have to wait and behave just like ``forward_sync``.
+    return remote_module.forward_async(*args).wait()
+
+
+# RPC handler for getting training mode on the destination worker.
+def get_remote_training_arg(module_rref):
+    return module_rref.local_value().training
+
+
+class ModuleCreationMode(enum.Enum):
+    """Ways in which the tests construct a remote module."""
+
+    # Built through ``_RemoteModule`` with a scripted module and an interface class.
+    MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
+    # Built through ``RemoteModule`` from a plain ``nn.Module`` class.
+    MODULE_CTOR = "module_ctor"
+
+
+@torch.jit.interface
+class MyModuleInterface:
+    # TorchScript interface declaring only ``forward``; the signature matches
+    # ``MyModule.forward``. The body is a declaration only.
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+
+@torch.jit.interface
+class RemoteMyModuleInterface:
+    # TorchScript interface declaring ``forward`` and ``forward_async``; the
+    # bodies are declarations only.
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+    # Same arguments as ``forward``, but the result is wrapped in a ``Future``.
+    def forward_async(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> Future[tuple[str, int, Tensor]]:
+        pass
+
+
+class MyModule(nn.Module):
+    # Minimal module used as the remote module under test: it holds one
+    # parameter and its ``forward`` returns its arguments in reverse order.
+    def __init__(self, first_arg, first_kwarg=-1):
+        # ``first_arg`` / ``first_kwarg`` are accepted but not stored.
+        super().__init__()
+        self.param1 = _PARAM_VAL
+
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # Returns ``(word, number, tensor)``, i.e. the inputs reversed.
+        return word, number, tensor
+
+
+class BadModule:
+    """Class with ``MyModule``'s constructor signature that is not an ``nn.Module``."""
+
+    def __init__(self, first_arg, first_kwarg=-1):
+        # Arguments are accepted but ignored.
+        pass
+
+
+def create_scripted_module(first_arg, first_kwarg=-1):
+    module = MyModule(first_arg, first_kwarg=first_kwarg)
+    scripted_module = torch.jit.script(module)
+    return scripted_module
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonRemoteModuleTest(RpcAgentTestFixture):
+    @property
+    def world_size(self):  # Override setting in RpcAgentTestFixture
+        return 2
+
+    @staticmethod
+    def _create_remote_module_iter(remote_device, modes=None):
+        if modes is None:
+            modes = ModuleCreationMode.__members__.values()
+
+        args = (1,)
+        kwargs = dict(first_kwarg=2)
+
+        if ModuleCreationMode.MODULE_CTOR in modes:
+            remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
+            yield remote_module
+
+        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
+            remote_module = _RemoteModule(
+                remote_device,
+                create_scripted_module,
+                args,
+                kwargs,
+                _module_interface_cls=MyModuleInterface,
+            )
+            scripted_remote_module = torch.jit.script(remote_module)
+            yield scripted_remote_module
+
+
+class RemoteModuleTest(CommonRemoteModuleTest):
+    @dist_utils.dist_init
+    def test_bad_module(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        remote_device = f"{dst_worker_name}/cpu"
+        args = (1,)
+        kwargs = dict(first_kwarg=2)
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,",
+        ):
+            RemoteModule(remote_device, BadModule, args, kwargs).forward()
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,",
+        ):
+            RemoteModule(remote_device, BadModule, args, kwargs).forward()
+
+    @dist_utils.dist_init
+    def test_forward_async(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2, "3")
+        for remote_module in self._create_remote_module_iter(dst_worker_name):
+            ret_fut = remote_module.forward_async(*args)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_forward_async_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            )
+        )
+
+        @torch.jit.script
+        def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
+            ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
+            ret = ret_fut.wait()
+            return ret
+
+        ret = run_forward_async(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+
+    @dist_utils.dist_init
+    def test_forward_sync(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2, "3")
+        for remote_module in self._create_remote_module_iter(dst_worker_name):
+            ret = remote_module.forward(*args)
+            self.assertEqual(ret, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_forward_sync_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            )
+        )
+
+        @torch.jit.script
+        def run_forward(scripted_remote_module: MyModuleInterface):
+            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
+            return ret
+
+        ret = run_forward(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+
+    @dist_utils.dist_init
+    def test_forward_with_kwargs(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2)
+        kwargs = dict(word="3")
+        # Only test Python nn.Module, because script module methods don't support taking kwargs.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            ret_fut = remote_module.forward_async(*args, **kwargs)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args + ("3",))))
+
+            ret = remote_module.forward(*args, **kwargs)
+            self.assertEqual(ret, tuple(reversed(args + ("3",))))
+
+    @dist_utils.dist_init
+    def test_remote_parameters(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # Only test Python nn.Module, because script module methods don't support ``remote_parameters``.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            param_rrefs = remote_module.remote_parameters()
+            self.assertEqual(len(param_rrefs), 1)
+            self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL))
+
+    @dist_utils.dist_init
+    def test_get_module_rref(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # Only test Python nn.Module, because script module methods don't support ``get_module_rref``.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            rref = remote_module.get_module_rref()
+            self.assertEqual(rref, remote_module.module_rref)
+            for param in rref.to_here().parameters():
+                self.assertTrue(torch.equal(param, _PARAM_VAL))
+
+    @dist_utils.dist_init
+    def test_train_eval(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            remote_module.train()
+            ret1 = rpc.rpc_sync(
+                dst_worker_name,
+                get_remote_training_arg,
+                args=(remote_module.get_module_rref(),),
+            )
+            self.assertEqual(ret1, True)
+
+            remote_module.eval()
+            ret2 = rpc.rpc_sync(
+                dst_worker_name,
+                get_remote_training_arg,
+                args=(remote_module.get_module_rref(),),
+            )
+            self.assertEqual(ret2, False)
+
+    @dist_utils.dist_init
+    def test_unsupported_methods(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``register_buffer`` not supported for RemoteModule"
+            ):
+                remote_module.register_buffer("buffer", torch.ones(5))
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_parameter`` not supported for RemoteModule",
+            ):
+                remote_module.register_parameter(
+                    "param", torch.nn.Parameter(torch.ones(1))
+                )
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``add_module`` not supported for RemoteModule"
+            ):
+                remote_module.add_module("empty", None)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``apply`` not supported for RemoteModule"
+            ):
+                fn = torch.rand((3, 3), requires_grad=False)
+                remote_module.apply(fn)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``cuda`` not supported for RemoteModule"
+            ):
+                remote_module.cuda()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``cpu`` not supported for RemoteModule"
+            ):
+                remote_module.cpu()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``type`` not supported for RemoteModule"
+            ):
+                remote_module.type(torch.FloatTensor)
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``float`` not supported for RemoteModule"
+            ):
+                remote_module.float()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``double`` not supported for RemoteModule"
+            ):
+                remote_module.double()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``bfloat16`` not supported for RemoteModule"
+            ):
+                remote_module.bfloat16()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``to`` not supported for RemoteModule"
+            ):
+                remote_module.to("cpu", dtype=torch.int32)
+
+            def hook(module, grad_input, grad_output):
+                pass
+
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_backward_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_backward_hook(hook)
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_forward_pre_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_forward_pre_hook(hook)
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_forward_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_forward_hook(hook)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``state_dict`` not supported for RemoteModule"
+            ):
+                remote_module.state_dict()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``load_state_dict`` not supported for RemoteModule"
+            ):
+                remote_module.load_state_dict({})
+
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.",
+            ):
+                remote_module.parameters()
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``named_parameters`` not supported for RemoteModule",
+            ):
+                remote_module.named_parameters()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``buffers`` not supported for RemoteModule"
+            ):
+                remote_module.buffers()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_buffers`` not supported for RemoteModule"
+            ):
+                remote_module.named_buffers()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``children`` not supported for RemoteModule"
+            ):
+                remote_module.children()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_children`` not supported for RemoteModule"
+            ):
+                remote_module.named_children()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``modules`` not supported for RemoteModule"
+            ):
+                remote_module.modules()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_modules`` not supported for RemoteModule"
+            ):
+                remote_module.named_modules()
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``requires_grad_`` not supported for RemoteModule"
+            ):
+                remote_module.requires_grad_()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``zero_grad`` not supported for RemoteModule"
+            ):
+                remote_module.zero_grad()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``share_memory`` not supported for RemoteModule"
+            ):
+                remote_module.share_memory()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``extra_repr`` not supported for RemoteModule"
+            ):
+                remote_module.extra_repr()
+
+    @dist_utils.dist_init
+    def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # If a new attribute is added to this RemoteModule after the initialization,
+        # and it will be sent over the wire by RPC,
+        # this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
+        # Note that adding a new attribute out of constructor should rarely happen.
+        # If a new attribute is added to RemoteModule constructor,
+        # there is a sanity check to enforce developers to add this attribute to either
+        # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            new_attr_name = "new_attr"
+            setattr(remote_module, new_attr_name, 1)
+
+            attrs = rpc.rpc_sync(
+                dst_worker_name, remote_module_attributes, (remote_module,)
+            )
+            self.assertNotIn(new_attr_name, attrs)
+
+    @dist_utils.dist_init
+    def test_remote_module_py_pickle_not_supported(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            with TemporaryFileName() as fname:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC",
+                ):
+                    torch.save(remote_module, fname)
+
+    @dist_utils.dist_init
+    def test_remote_module_py_pickle_not_supported_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+        ):
+            with (
+                TemporaryFileName() as fname,
+                self.assertRaisesRegex(
+                    torch.jit.Error, "can only be pickled when using RPC"
+                ),
+            ):
+                torch.save(remote_module, fname)
+
+
+class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest):
+    @property
+    def world_size(self):  # Override setting in CommonRemoteModuleTest
+        return 3
+
+    @dist_utils.dist_init
+    def test_send_remote_module_over_the_wire(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Unpickled attributes include both the inherent attributes of RemoteModule
+        # (not inherited from the superclass) and two installed methods.
+        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
+        expected_unpickled_attrs.append("forward_async")
+        expected_unpickled_attrs.append("forward")
+
+        # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            # Test querying some simple attributes from worker2.
+            attrs = rpc.rpc_sync(
+                dst_worker2_name, remote_module_attributes, (remote_module,)
+            )
+            self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs)
+            self.assertEqual(attrs["on"], "worker1")
+            self.assertEqual(attrs["device"], "cpu")
+            self.assertFalse(attrs["is_device_map_set"])
+            self.assertFalse(attrs["is_scriptable"])
+
+            # Test that the methods installed on worker1 can be invoked by worker2 over the RPC layer.
+            # NOTE: In practice a remote module should be stored directly on the worker that runs ``forward`` or ``forward_async``,
+            # rather than having another worker initiate forward over the RPC layer.
+            args = (torch.ones(1), 2, "3")
+            ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args))
+            self.assertEqual(ret1, tuple(reversed(args)))
+            ret2 = rpc.rpc_sync(
+                dst_worker2_name, remote_forward_async, (remote_module, args)
+            )
+            self.assertEqual(ret2, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_send_remote_module_over_the_wire_script_not_supported(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Unpickled attributes include both the inherent attributes of RemoteModule
+        # (not inherited from the superclass) and two installed methods.
+        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
+        expected_unpickled_attrs.append("forward_async")
+        expected_unpickled_attrs.append("forward")  # NOTE(review): this list is never asserted against in this test.
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Passing a script RemoteModule over RPC is not supported."
+        ):
+            # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
+            for remote_module in self._create_remote_module_iter(
+                dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            ):
+                # Test querying some simple attributes from worker2.
+                rpc.rpc_sync(
+                    dst_worker2_name, remote_module_attributes, (remote_module,)
+                )
+
+    @dist_utils.dist_init
+    def test_create_remote_module_from_module_rref(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            remote_module2 = rpc.rpc_sync(
+                dst_worker2_name,
+                RemoteModule.init_from_module_rref,
+                (dst_worker2_name, remote_module.get_module_rref()),
+            )
+
+            args = (torch.ones(1), 2, "3")
+            ret1 = rpc.rpc_sync(dst_worker1_name, remote_forward, (remote_module, args))
+            ret2 = rpc.rpc_sync(
+                dst_worker2_name, remote_forward, (remote_module2, args)
+            )
+            self.assertEqual(ret1, ret2)
+
+
+class CudaRemoteModuleTest(CommonRemoteModuleTest):
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_valid_device(self):
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = dist_utils.worker_name(dst_rank)
+
+        for remote_module in self._create_remote_module_iter(
+            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
+        # Test rank works as well.
+        for remote_module in self._create_remote_module_iter(
+            f"rank:{dst_rank}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_invalid_devices(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Expected one of .+ device type at start of device string",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/foo",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        if TEST_WITH_ROCM:
+            errorString = (
+                r"HIP error: invalid device ordinal\n"
+                r"HIP kernel errors might be asynchronously reported at some other API call, "
+                r"so the stacktrace below might be incorrect.\n"
+                r"For debugging consider passing AMD_SERIALIZE_KERNEL=3"
+            )
+        else:
+            errorString = r"CUDA error: invalid device ordinal"
+        with self.assertRaisesRegex(RuntimeError, errorString):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cuda:100",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cpu2",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cuda:0/cuda:1",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: /. The valid format is '/'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    "/",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: /cuda:0. The valid format is '/'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    "/cuda:0",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_input_moved_to_cuda_device(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate cuda device.
+        t1 = torch.ones(1)
+        args = (t1, 2)
+        t2 = t1 * 2
+        kwargs = dict(word=t2)
+
+        # Only test Python nn.Module, because script module methods don't support taking kwargs.
+        for remote_module in self._create_remote_module_iter(
+            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            ret_fut = remote_module.forward_async(*args, **kwargs)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args + (t2,))))
+            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+            self.assertEqual(ret[0].device.type, "cpu")
+            self.assertEqual(ret[2].device.type, "cpu")
+
+            ret = remote_module.forward(*args, **kwargs)
+            self.assertEqual(ret, tuple(reversed(args + (t2,))))
+            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+            self.assertEqual(ret[0].device.type, "cpu")
+            self.assertEqual(ret[2].device.type, "cpu")
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_input_moved_to_cuda_device_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                f"{dst_worker_name}/cuda:0",
+                modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE],
+            )
+        )
+
+        @torch.jit.script
+        def run_forward(scripted_remote_module: MyModuleInterface):
+            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
+            return ret
+
+        ret = run_forward(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+        # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+        self.assertEqual(ret[2].device.type, "cpu")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..313a0d0ec9b972b5f95661e37c12161f33615512
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6219189a3534903aa60d227678c63d598d8b5548
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b09b6b6bfc034010c7398bf330e9a7b305139d66
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4296d6de67c821c31ddf0f3f63420c255cabb77
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03516c33942aa4792882d1bfcb2cee5a33ea7867
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..753cb9d67e6bd5efa42da5e898e51b970edbf020
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1abadd33309da7c933ea03ec300e67d05d343600
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -0,0 +1,2756 @@
+# mypy: allow-untyped-defs
+
+import random
+import sys
+import threading
+import time
+from datetime import timedelta
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.testing._internal.dist_utils
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.distributed.rpc import RRef
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import (
+    IS_MACOS,
+    skip_but_pass_in_sandcastle_if,
+)
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    initialize_pg,
+    wait_until_node_failure,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# Right now we test up to 3-layer nested rpc calls.
+# rpc_done[1] and ctx_ids[1] record, respectively, that the rpc is done on the
+# prev rank and the context id sent from the prev rank.
+# rpc_done[2] and ctx_ids[2] are the same for the prev of the prev rank.
+# rpc_done[3] and ctx_ids[3] are the same for the prev of prev of prev rank.
+# rpc_done[0] and ctx_ids[0] are for the current rank, but mostly not used.
+rpc_done = [False, False, False, False]
+ctx_ids = [-1, -1, -1, -1]
+
+known_context_ids = set()
+
+requires_grad_tensor = torch.ones(3, 3, requires_grad=True)
+
+
+# Send rpc done info and context_id to
+# dst_rank = (self.rank + rank_distance) % self.world_size
+# we don't need a lock here since the GIL is held while executing remote
+# python UDFs, so access is serialized across several workers.
+def _set_rpc_done(ctx_id, rank_distance):
+    """Record ``ctx_id`` for the rank ``rank_distance`` away, then flag it done."""
+    # Only in-place mutation happens here, so no ``global`` statements are needed.
+    known_context_ids.add(ctx_id)
+    ctx_ids[rank_distance] = ctx_id
+    # Publish the flag last: pollers read ctx_ids[...] right after seeing it set.
+    rpc_done[rank_distance] = True
+
+
+def _check_rpc_done(rank_distance):  # Block until rpc_done[rank_distance] is truthy.
+    while not rpc_done[rank_distance]:
+        time.sleep(0.1)  # Poll again after 0.1 seconds.
+
+
+def _torch_ones(sizes, requires_grad=False):
+    return torch.ones(sizes, requires_grad=requires_grad)
+
+
+# This method must be called on the rref owner, and verifies that the grad of
+# the rref tensor equals the given grad.
+def _compare_owner_value(context_id, rref, grad):
+    """Return whether the grad stored for ``rref``'s local value equals ``grad``."""
+    owner_grad = dist_autograd.get_gradients(context_id)[rref.local_value()]
+    if not owner_grad.is_sparse:
+        assert not grad.is_sparse
+        return torch.equal(owner_grad, grad)
+    # Sparse case: the expected grad must be sparse too; compare dense forms.
+    assert grad.is_sparse
+    dense_owner_grad = owner_grad.to_dense()
+    return torch.equal(dense_owner_grad, grad.to_dense())
+
+
+def create_tensor():
+    return torch.ones((3, 3), requires_grad=True)
+
+
+def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
+    """Build a fixed 3x3 sparse COO tensor with three explicit entries."""
+    indices = [[0, 1, 1], [2, 0, 2]]
+    values = [3.2, 4.1, 5.3]
+    sparse = torch.sparse_coo_tensor(
+        indices, values, (3, 3), requires_grad=requires_grad, dtype=dtype
+    )
+    # Coalescing is opt-in; by default the tensor is returned as constructed.
+    return sparse.coalesce() if coalesce else sparse
+
+
+@torch.jit.script
+def create_torchscript_tensor() -> torch.Tensor:
+    return torch.ones((3, 3)).requires_grad_()
+
+
+def my_py_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+def my_scalar_add(a, b):
+    return a + b
+
+
+def my_rref_add(rref_t1, t2):
+    """Add ``t2`` to the value held locally by ``rref_t1``."""
+    return torch.add(rref_t1.local_value(), t2)
+
+
+@torch.jit.script
+def my_script_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+@torch.jit.script
+def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor:
+    # Resolve the RRef with ``to_here()`` and add ``t2`` to the result.
+    return torch.add(ref_t1.to_here(), t2)
+
+
+def my_nested_rref_add(dst, rref_t1, t2):
+    return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))
+
+
+def ret_requires_grad():
+    return requires_grad_tensor
+
+
+def my_py_nested_call(t1, t2, dst, world_size, hops):
+    """Relay to the rank after ``dst``; recurse while ``hops`` > 0, else add there."""
+    next_dst = (dst + 1) % world_size
+    if not hops > 0:
+        return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2))
+    return rpc.rpc_sync(
+        worker_name(next_dst),
+        my_py_nested_call,
+        args=(t1, t2, next_dst, world_size, hops - 1),
+    )
+
+
+# After a dist autograd context is cleaned up, it should also be cleaned up on
+# other nodes. This helper allows timeout_seconds for those RPCs to complete, and
+# checks that all the contexts have been cleaned up in that timeframe.
+def _all_contexts_cleaned_up(timeout_seconds=10):
+    """Poll until every known context id fails retrieval or the timeout expires."""
+    started_at = time.time()
+    cleaned_ids = set()
+    while (
+        time.time() - started_at < timeout_seconds
+        and cleaned_ids != known_context_ids
+    ):
+        for context_id in known_context_ids:
+            try:
+                dist_autograd._retrieve_context(context_id)
+            except RuntimeError:
+                # A RuntimeError on retrieval is treated as "cleaned up".
+                cleaned_ids.add(context_id)
+    # Success only if retrieving every known context raised a RuntimeError.
+    return cleaned_ids == known_context_ids
+
+
+# This function creates a dist autograd context, runs rpc_sync on the given ps,
+# and then blocks until the ps has verified the grads are correctly accumulated.
+def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
+    """Run ``my_rref_add`` on ``ps`` and backprop its summed result in a new context."""
+    with dist_autograd.context() as context_id:
+        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
+        loss = torch.sparse.sum(ret) if sparse else ret.sum()
+        dist_autograd.backward(context_id, [loss])
+        # Report completion to the ps, then wait on its rpc_done[0] flag; staying
+        # inside the ``with`` block until then prevents deleting the dist
+        # autograd context.
+        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
+        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
+
+
+# This function is the same as _run_trainer, except rpc calls torchscript
+# function "my_script_ref_add" instead of python function "my_rref_add"
+def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
+    """Run ``my_script_ref_add`` on ``ps`` and backprop its summed result."""
+    with dist_autograd.context() as context_id:
+        ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
+        loss = torch.sparse.sum(ret) if sparse else ret.sum()
+        dist_autograd.backward(context_id, [loss])
+        # Report completion to the ps, then wait on its rpc_done[0] flag; staying
+        # inside the ``with`` block until then prevents deleting the dist
+        # autograd context.
+        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
+        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
+
+
+class SimulateBackwardError(Function):
+    """Identity op whose backward pass raises while ``_simulate_error`` is True."""
+    _simulate_error = True
+
+    @staticmethod
+    def forward(ctx, input):
+        return input
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, input):
+        if not SimulateBackwardError._simulate_error:
+            return input
+        raise Exception("Simulate error on backward pass")  # noqa: TRY002
+
+
+class ExecMode(Enum):
+    LOCAL = 1  # Run the operation locally.
+    RPC_SYNC = 2  # Run the operation using rpc_sync
+    REMOTE = 3  # Run the operation using remote.
+    RPC_ASYNC = 4  # Run the operation using rpc_async
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonDistAutogradTest(RpcAgentTestFixture):
+    def _exec_func_with_dst(self, dst, exec_mode, method, *args):
+        """Call ``method`` with ``args`` locally or on worker ``dst`` per ``exec_mode``."""
+        if ExecMode.LOCAL == exec_mode:
+            # A lone list argument is unpacked into positional arguments.
+            unpack_list = len(args) == 1 and isinstance(args[0], list)
+            return method(*args[0]) if unpack_list else method(*args)
+        if ExecMode.RPC_SYNC == exec_mode:
+            return rpc.rpc_sync(worker_name(dst), method, args=args)
+        if ExecMode.REMOTE == exec_mode:
+            return rpc.remote(worker_name(dst), method, args=args).to_here()
+        if ExecMode.RPC_ASYNC == exec_mode:
+            return rpc.rpc_async(worker_name(dst), method, args=args).wait()
+        # Any other value is not a mode this helper can run.
+        raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+    def _exec_func(self, exec_mode, method, *args):
+        return self._exec_func_with_dst(self._next_rank(), exec_mode, method, *args)
+
+    def _next_rank(self):
+        """Advance and return ``self.dst_rank`` round-robin, skipping ``self.rank``."""
+        if not hasattr(self, "dst_rank"):
+            self.dst_rank = (self.rank + 1) % self.world_size
+            return self.dst_rank
+        self.dst_rank = (self.dst_rank + 1) % self.world_size
+        # Landed on our own rank: advance once more via a recursive call.
+        return self._next_rank() if self.dst_rank == self.rank else self.dst_rank
+
+    def _check_rpc_done(self, rank_distance):
+        _check_rpc_done(rank_distance)
+
+    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
+        """Run backward locally and return ``args`` grads, or verify it remotely."""
+        if exec_mode != ExecMode.LOCAL:
+            return self._verify_backwards_remote(tensors, context_id, local_grads, *args)
+        torch.autograd.backward(tensors)
+        return [arg.grad for arg in args]
+
+    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
+        """Run dist autograd backward and check the context's grads against ``local_grads``."""
+        dist_autograd.backward(context_id, tensors)
+
+        grads = dist_autograd.get_gradients(context_id)
+        expected_count = 0
+        for idx, arg in enumerate(args):
+            expected = local_grads[idx]
+            if expected is None:
+                self.assertNotIn(arg, grads)
+                continue
+            self.assertIn(arg, grads)
+            self.assertEqual(expected, grads[arg])
+            expected_count += 1
+
+        self.assertEqual(expected_count, len(grads))
+
+    def _test_graph(self, fn, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor()
+                t2 = build_sparse_tensor()
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(worker_name(dst_rank), fn, args=(t1, t2)).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            # Verify graph for current context id.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                next(iter(recv_functions.values())),
+                t1,
+                t2,
+                ret,
+            )
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # Verify graph for previous context id.
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+        # autograd context should be cleaned up by now.
+        with self.assertRaises(RuntimeError):
+            ctx = dist_autograd._retrieve_context(context_id)
+
+        # No autograd context available.
+        with self.assertRaises(RuntimeError):
+            ctx = dist_autograd._current_context()
+
+    # 3-layer nested calls
+    def _test_graph_for_py_nested_call(self, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(t1, t2, dst_rank, self.world_size, 1),
+                )
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(t1, t2, dst_rank, self.world_size, 1),
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            # Barrier to ensure all RPCs are done.
+            dist.barrier()
+
+            for rd in [1, 2, 3]:
+                rpc.rpc_sync(
+                    worker_name((self.rank + rd) % self.world_size),
+                    _set_rpc_done,
+                    args=(context_id, rd),
+                )
+
+            # Barrier to ensure all set_rpc_done have completed.
+            dist.barrier()
+
+            # For self.rank, there are 4 graphs to verify.
+            # One is for the current context id, when this rank sends the first rpc call.
+            # The second is for the prev context id, when this rank makes the 1st nested
+            # call.
+            # The third is for the prev prev context id, when this rank makes the
+            # 2nd nested call.
+            # The last is for the prev prev prev context id, when this rank
+            # executes the torch.add() operator.
+
+            # Verify first graph for current context id.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                next(iter(recv_functions.values())),
+                t1,
+                t2,
+                ret,
+            )
+
+            # Verify second graph for 1st nested call.
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            self._verify_graph_for_nested_rpc_call(ctx)
+
+            # Verify third graph for 2nd nested call.
+            ctx = dist_autograd._retrieve_context(ctx_ids[2])
+            self._verify_graph_for_nested_rpc_call(ctx)
+
+            # verify last graph for rpc call execution.
+            ctx = dist_autograd._retrieve_context(ctx_ids[3])
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+    # Rank0->Rank1->Rank0
+    def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(
+                        t1,
+                        t2,
+                        (self.rank - 1 + self.world_size) % self.world_size,
+                        self.world_size,
+                        0,
+                    ),
+                )
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(
+                        t1,
+                        t2,
+                        (self.rank - 1 + self.world_size) % self.world_size,
+                        self.world_size,
+                        0,
+                    ),
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(
+                worker_name((self.rank + 1) % self.world_size),
+                _set_rpc_done,
+                args=(context_id, 1),
+            )
+
+            # For self.rank, it has 2 graphs to verify.
+            # One is for current context id when this rank send first rpc
+            # call and execute the torch.add() operator.
+            # Another one is for prev context id when this rank make
+            # nested call.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(2, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(2, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                list(recv_functions.values())[1],
+                t1,
+                t2,
+                ret,
+            )
+            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])
+
+            # Verify two pairs of send and recv functions for nested
+            # call
+            self._check_rpc_done(1)
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            self._verify_graph_for_nested_rpc_call(ctx)
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=False)
+                t2 = build_sparse_tensor(requires_grad=False)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=False)
+                t2 = torch.zeros(3, 3, requires_grad=False)
+            if ExecMode.RPC_SYNC == exec_mode:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+            elif ExecMode.REMOTE == exec_mode:
+                rpc.remote(worker_name(dst_rank), torch.add, args=(t1, t2)).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            ctx = dist_autograd._current_context()
+            send_functions = ctx._send_functions()
+            self.assertEqual(len(send_functions), 0)
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(len(recv_functions), 0)
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # NB: RRef.to_here() always passes the autograd context to the
+            # the callee, as the caller does not know whether the return
+            # value would contain a requires_grad tensor or not.
+            #
+            # rpc/remote with udf (_set_rpc_done here) also always passes the
+            # autograd context to the callee due to the same reason.
+            self.assertNotEqual(-1, dist_autograd._retrieve_context(ctx_ids[1]))
+            dist.barrier()
+
+    def _test_rpc_complex_args(self, exec_mode, sparse):
+        with dist_autograd.context():
+            num_tensors = 10
+            tensors = []
+            for i in range(num_tensors):
+                if sparse:
+                    tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
+                else:
+                    tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
+                tensors.append(tensor)
+            dst_rank = self._next_rank()
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), torch.stack, args=(tensors,))
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank), torch.stack, args=(tensors,)
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            self.assertEqual(torch.stack(tensors), ret)
+
+            # Verify appropriate tensors have been attached the autograd graph.
+            next_funcs = next(
+                iter(dist_autograd._current_context()._send_functions().values())
+            ).next_functions
+            for i in range(len(next_funcs)):
+                self.assertEqual(
+                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
+                )
+                self.assertEqual(tensors[i], next_funcs[i][0].variable)
+
+            # Verify that the worker id has been recorded in the context
+            ctx = dist_autograd._current_context()
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(len(worker_ids), 1)
+            self.assertEqual(worker_ids, {dst_rank})
+
+    def context_cleanup_test_helper(self, rpc_args, func, nested=False):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # test that in dist autograd, in the case that tensors communicated over RPC do
+        # NOT require grad, we still cleanup the dist autograd contexts created
+        # on other nodes. This is because the autograd context is still
+        # communicated over RPC even if tensor arguments do not require grad, as
+        #  it is possible that the response could.
+        if nested:
+            dst_rank = (self.rank + 1) % self.world_size
+            nested_dst_rank = (dst_rank + 1) % self.world_size
+            dst_ranks = {dst_rank}
+        else:
+            dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+
+        with dist_autograd.context() as context_id:
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+                if nested:
+                    rpc.rpc_sync(
+                        worker_name(nested_dst_rank),
+                        _set_rpc_done,
+                        args=(context_id, 2),
+                    )
+        # the thread's context id should be cleaned up
+        with self.assertRaises(RuntimeError):
+            dist_autograd._retrieve_context(context_id)
+        # Ensure all peers have finished mutating the
+        # `known_context_ids` set.
+        dist.barrier()
+        # check that all contexts have been cleaned up.
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
+    def _backward_no_grad_on_tensor(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            self.assertIsNone(t1.grad)
+            self.assertIsNone(t2.grad)
+
+            # Now populate .grad with local autograd engine and
+            # verify dist autograd doesn't mess with it.
+            loss_local = torch.add(t1, t2)
+            if sparse:
+                loss_local = torch.sparse.sum(loss_local)
+            else:
+                loss_local = loss_local.sum()
+            loss_local.backward()
+            self.assertIsNotNone(t1.grad)
+            self.assertIsNotNone(t2.grad)
+
+            t1_grad_before = t1.grad
+            t2_grad_before = t2.grad
+            dist_autograd.backward(context_id, [loss])
+            self.assertEqual(t1_grad_before, t1.grad)
+            self.assertEqual(t2_grad_before, t2.grad)
+
+    # The current rank first creates a tensor on the rref_owner, and then passes
+    # the rref with another tensor to the callee to run either my_rref_add or
+    # my_nested_rref_add, depending on whether the callee is the rref owner.
+    # The grad of tensor lives on the current rank, and the grad of the rref
+    # tensor lives on the rref owner.
+    def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
+        local_ret = torch.add(t1, t2)
+        if sparse:
+            local_ret = torch.sparse.sum(local_ret)
+        else:
+            local_ret = local_ret.sum()
+        local_ret.backward()
+        with dist_autograd.context() as context_id:
+            if sparse:
+                rref_t1 = rpc.remote(
+                    rref_owner,
+                    build_sparse_tensor,
+                    args=(
+                        False,
+                        True,
+                    ),
+                )
+            else:
+                rref_t1 = rpc.remote(
+                    rref_owner,
+                    _torch_ones,
+                    args=((3, 3),),
+                    kwargs={"requires_grad": True},
+                )
+            if callee == rref_owner:
+                rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
+            else:
+                rref = rpc.remote(
+                    callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
+                )
+            ret = rref.to_here()
+            if sparse:
+                ret = torch.sparse.sum(ret)
+            else:
+                ret = ret.sum()
+            dist_autograd.backward(context_id, [ret])
+
+            # verify grads on caller
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertIn(t2, grads)
+            self.assertEqual(grads[t2], t2.grad)
+
+            # verify grads on rref owner
+            self.assertTrue(
+                rpc.rpc_sync(
+                    rref_owner,
+                    _compare_owner_value,
+                    args=(context_id, rref_t1, t1.grad),
+                )
+            )
+
+    # In this test, every rank will serve as a parameter server (ps) and a
+    # driver, and then kicks off trainers on the other three ranks. So, we have:
+    # ps = rank0 with trainers = rank1/2/3
+    # ps = rank1 with trainers = rank2/3/0
+    # ps = rank2 with trainers = rank3/0/1
+    # ps = rank3 with trainers = rank0/1/2
+    #
+    # These four test ps-trainer groups run on completely separate autograd
+    # graphs, but they share the same set of underlying RpcAgents.
+    def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
+        if sparse:
+            t1 = build_sparse_tensor(requires_grad=True)
+            t2 = build_sparse_tensor(requires_grad=True)
+        else:
+            t1 = torch.ones((3, 3), requires_grad=True)
+            t2 = torch.zeros((3, 3), requires_grad=True)
+
+        local_ret = torch.add(t1, t2)
+        if sparse:
+            torch.sparse.sum(local_ret).backward()
+        else:
+            local_ret.sum().backward()
+
+        # create rref on self
+        rref_t1 = rpc.remote(worker_name(self.rank), create_ref_fn, args=())
+
+        # kick off forward and backward pass on three other workers (trainers)
+        rank_diffs = [1, 2, 3]
+        futures = [
+            rpc.rpc_async(
+                worker_name((self.rank + rank_diff) % self.world_size),
+                trainer_fn,
+                args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
+            )
+            for rank_diff in rank_diffs
+        ]
+
+        # check if the trainers have done with their backward pass
+        for rank_diff in rank_diffs:
+            self._check_rpc_done(rank_diff)
+
+        # trainers are done and holding the context for verification
+        for rank_diff in rank_diffs:
+            # make sure grads are accumulated for the same tensors and values
+            # are all correct
+            ctx_id = ctx_ids[rank_diff]
+            grads = dist_autograd.get_gradients(ctx_id)
+            local_t1 = rref_t1.to_here()
+            self.assertIn(local_t1, grads)
+            self.assertEqual(grads[local_t1], t1.grad)
+
+        # unblock trainers
+        _set_rpc_done(None, 0)
+
+        # wait until all trainers are done
+        torch.futures.wait_all(futures)
+
+    def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                # Multiple RPCs between different nodes.
+                val = self._exec_func(exec_mode, torch.add, t1, t2)
+                val = self._exec_func(exec_mode, torch.mul, t3, val)
+                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
+                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
+                if sparse:
+                    val = self._exec_func(exec_mode, torch.mul, s1, s2)
+                    val = self._exec_func(exec_mode, torch.mul, val, val)
+                    loss = torch.sparse.sum(val)
+                else:
+                    val = self._exec_func(exec_mode, torch.bmm, s1, s2)
+                    val = self._exec_func(exec_mode, torch.matmul, val, val)
+                    loss = val.sum()
+
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
+                )
+                local_grads = ret if ret else local_grads
+
+    def _backward_different_dtypes(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                loss = self._exec_func(exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(loss)
+                else:
+                    loss = loss.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple_python_udf(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, my_py_add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple_script_call(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [
+            ExecMode.LOCAL,
+            ExecMode.RPC_SYNC,
+            ExecMode.RPC_ASYNC,
+            ExecMode.REMOTE,
+        ]:
+            with dist_autograd.context() as context_id:
+                forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(forward_ret)
+                else:
+                    loss = forward_ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            ret = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._test_nested_backward_accumulate_grads,
+                args=(t1, t2, self._next_rank()),
+            )
+            if sparse:
+                loss = torch.sparse.sum(ret)
+            else:
+                loss = ret.sum()
+            # Run backward twice.
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            dist_autograd.backward(context_id, [loss])
+
+    def _backwards_nested_python_udf(self, t1, t2, sparse):
+        t3 = t1 * t2
+        t4 = t1 + t2
+        res = t3 + t4
+        loss = t1 * t2 * t3 * t4 * res
+        if sparse:
+            loss = torch.sparse.sum(loss)
+        else:
+            loss = loss.sum()
+        torch.autograd.backward([loss])
+
+        # Now run distributed autograd.
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._nested_python_udf,
+                args=(t1, t2, self._next_rank()),
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    def _mixed_requires_grad(self, t1, t2, sparse):
+        for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(
+                    exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2
+                )
+                self.assertEqual(t1 * t2, ret)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                dist_autograd.backward(context_id, [loss])
+                self.assertTrue(t1.requires_grad)
+                self.assertFalse(t2.requires_grad)
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertIn(t1, grads)
+                self.assertNotIn(t2, grads)
+                self.assertEqual(t2, grads[t1])
+
+    def _multiple_backward(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            # Run backward in a loop multiple times.
+            for _ in range(1000):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+    # For current context, this rank sends t1 and t2 tensors to dst_rank,
+    # then get t3 = torch.add(t1, t2) result tensor.
+    # For the current context in this rank, it expects graph like this:
+    #  send function:
+    #              rpcSendBackward
+    #                  /          \
+    #  t1.AccumulateGrad         t2.AccumulateGrad
+    #
+    #  recv function:
+    #
+    #            |
+    #          t3.rpcRecvBackward
+    #
+    def _verify_graph_for_first_rpc_call(
+        self, send_function, recv_function, t1, t2, ret
+    ):
+        # Retrieve the next functions in the graph.
+        next_funcs = send_function.next_functions
+        self.assertEqual(2, len(next_funcs))
+
+        # We should now hit t1 and t2 in the autograd graph.
+        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
+        self.assertEqual(t1, next_funcs[0][0].variable)
+        self.assertEqual(0, next_funcs[0][1])
+        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
+        self.assertEqual(t2, next_funcs[1][0].variable)
+        self.assertEqual(0, next_funcs[1][1])
+
+        # Test recv functions.
+        self.assertEqual(ret.grad_fn, recv_function)
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple(self, dst, t1, t2, local_grads, sparse):
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func_with_dst(dst, exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    # For a context passed from previous nested chain calls, this rank
+    # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
+    # result tensor t3 back.
+    # For this context in this rank, it expects graph like this:
+    #  send and recv functions:
+    #       rpcSendBackward
+    #           |
+    #          t3.AddBackward0
+    #          /             \
+    # t1.recvRpcBackward    t2.recvRpcBackward
+    def _verify_graph_for_rpc_call_exec(self, send_function):
+        # Verify next function is AddBackward0
+        next_funcs = send_function.next_functions
+        self.assertEqual(1, len(next_funcs))
+        add_backward_fn = next_funcs[0][0]
+        self.assertEqual("AddBackward0", add_backward_fn.name())
+
+        # Verify the next two functions are the same recv backward function.
+        next_funcs = add_backward_fn.next_functions
+        self.assertEqual(2, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
+        )
+        self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
+    # For a context passed from previous nested chain calls, this rank
+    # receives two tensors t1 and t2, forwards t1 and t2 tensors using
+    # nested rpc call to next dst. In return route, receive result tensor t3
+    # from next dst and forwarding t3 back to previous calls.
+    # For this context in this rank, it expects graph like this:
+    #  send and recv functions for receiving and forwarding t1 and t2:
+    #       rpcSendBackward
+    #          /          \
+    # t1.recvRpcBackward    t2.recvRpcBackward
+    #  send and recv functions for receiving and forwarding t3:
+    #       rpcSendBackward
+    #             |
+    #           t3.recvRpcBackward
+    def _verify_graph_for_nested_rpc_call(self, ctx):
+        send_functions = ctx._send_functions()
+        self.assertEqual(2, len(send_functions))
+
+        # For send function when making nest rpc call,
+        # next functions of the send function are two recv functions
+        # for received two tensors from previous call
+        next_funcs = next(iter(send_functions.values())).next_functions
+        self.assertEqual(2, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
+        )
+        self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
+        # For send function when returning response to previous call
+        # next function of the send function is the recv function
+        # for received tensor result returned from nested call
+        next_funcs = list(send_functions.values())[1].next_functions
+        self.assertEqual(1, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+
+
+class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
+    # Sparse tests only work with TensorPipeAgent.
+    @dist_init
+    def test_graph_for_builtin_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_python_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_python_remote_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_py_nested_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_rpc_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_remote_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_context_cleanup_tensor_with_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_tensor_no_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_nested_rpc_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor_sparse(self):
+        self._backward_no_grad_on_tensor(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_sparse(self):
+        self._backward_simple(
+            self._next_rank(),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_self_sparse(self):
+        self._backward_simple(
+            self.rank,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_rref_multi_sparse(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                build_sparse_tensor(requires_grad=True),
+                build_sparse_tensor(requires_grad=True),
+                None,
+                True,
+            )
+
+    @dist_init
+    def test_backward_rref_sparse(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_rref_nested_sparse(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_trainer_ps_sparse(self):
+        self._test_trainer_ps(build_sparse_tensor, _run_trainer, True)
+
+    @dist_init
+    def test_backward_multiple_round_trips_sparse(self):
+        self._backward_multiple_round_trips(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_different_dtypes_sparse(self):
+        self._backward_different_dtypes(
+            build_sparse_tensor(requires_grad=True, dtype=torch.float32),
+            build_sparse_tensor(requires_grad=True, dtype=torch.float64),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf_sparse(self):
+        self._backward_simple_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_script_call_sparse(self):
+        self._backward_simple_script_call(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_nested_backward_accumulate_grads_sparse(self):
+        self._nested_backward_accumulate_grads(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backwards_nested_python_udf_sparse(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_mixed_requires_grad_sparse(self):
+        self._mixed_requires_grad(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            True,
+        )
+
+    @dist_init
+    def test_multiple_backward_sparse(self):
+        self._multiple_backward(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_embedding_bag_with_no_grad_tensors(self):
+        dst = self._next_rank()
+        remote_embedding = rpc.remote(
+            worker_name(dst),
+            torch.nn.EmbeddingBag,
+            args=(16, 16),
+            kwargs={"mode": "sum", "sparse": True},
+        )
+        local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True)
+
+        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
+        # requires_grad = True to record send/recv functions
+        per_sample_weights = torch.rand((8), requires_grad=True)
+        offsets = torch.LongTensor([0, 4])
+
+        local_res = local_embedding(input, offsets, per_sample_weights)
+
+        # Run backward twice.
+        torch.autograd.backward([local_res.sum()], retain_graph=True)
+        torch.autograd.backward([local_res.sum()])
+        local_grad = local_embedding.weight.grad
+
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync(
+                worker_name(dst),
+                DistAutogradTest._call_remote_embedding,
+                args=(remote_embedding, input, offsets, per_sample_weights),
+            )
+
+            # Run backward twice to test accumulation of sparse gradients.
+            dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
+            dist_autograd.backward(context_id, [res.sum()])
+
+            remote_grad = rpc.rpc_sync(
+                worker_name(dst),
+                DistAutogradTest._get_grad,
+                args=(remote_embedding, context_id),
+            )
+
+            self.assertEqual(local_grad, remote_grad)
+
+
+class DistAutogradTest(CommonDistAutogradTest):
+    @dist_init
+    def test_autograd_context(self):
+        # Verify max possible id.
+        max_auto_increment = 281474976710655
+        self.assertEqual(
+            max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id()
+        )
+
+        context_ids = []
+        for _ in range(200):
+            with dist_autograd.context() as context_id:
+                self.assertEqual(
+                    context_id,
+                    dist_autograd._retrieve_context(context_id)._context_id(),
+                )
+                # First 16 bits should be worker_id.
+                self.assertEqual(self.worker_id, context_id >> 48)
+                context_ids.append(context_id)
+
+        for context_id in context_ids:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                f"Could not find autograd context with id: {context_id}",
+            ):
+                dist_autograd._retrieve_context(context_id)
+
+    @dist_init
+    def test_nested_context(self):
+        with (
+            dist_autograd.context(),
+            self.assertRaisesRegex(
+                RuntimeError, "Already have an autograd context id for this thread"
+            ),
+            dist_autograd.context(),
+        ):
+            pass
+
+    @dist_init
+    def test_graph_for_builtin_call(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_python_call(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_python_remote_call(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)
+
+    def _test_grad_only_on_return_value(self, exec_mode):
+        """Check gradients when only the RPC return value requires grad.
+
+        Calls ``ret_requires_grad`` on the next rank (via ``rpc_sync`` or
+        ``remote().to_here()`` depending on ``exec_mode``), runs a distributed
+        backward on the summed result, and then inspects the gradients recorded
+        for the context id stored in ``ctx_ids[1]``.
+
+        Raises:
+            ValueError: if ``exec_mode`` is neither ``RPC_SYNC`` nor ``REMOTE``.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(worker_name(dst_rank), ret_requires_grad).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            dist_autograd.backward(context_id, [ret.sum()])
+
+            # Hand this worker's context id to the next rank.
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # ctx_ids[1] is presumably the context id received from the previous
+            # rank through _set_rpc_done -- confirm against that helper.
+            grads = dist_autograd.get_gradients(ctx_ids[1])
+            self.assertEqual(1, len(grads))
+            self.assertIn(requires_grad_tensor, grads)
+            self.assertEqual(torch.ones_like(ret), grads[requires_grad_tensor])
+            # due to the above get_gradients call, ensure that dist autograd
+            # contexts aren't cleaned up until all workers exit context managers
+            dist.barrier()
+
+    @dist_init
+    def test_grad_only_on_return_value(self):
+        self._test_grad_only_on_return_value(ExecMode.RPC_SYNC)
+
+    @dist_init
+    def test_grad_only_on_return_value_remote(self):
+        self._test_grad_only_on_return_value(ExecMode.REMOTE)
+
+    @dist_init
+    def test_rpc_complex_args(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_remote_complex_args(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_context_cleanup_tensor_with_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_tensor_no_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_no_tensors(self):
+        self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)
+
+    @dist_init
+    def test_context_cleanup_nested_rpc(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
+    def test_worker_ids_recorded(self):
+        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+        with dist_autograd.context() as context_id:
+            # if no tensors require grad, we should still record worker_ids, as
+            # the autograd context ID is still passed to other workers.
+            t1 = torch.ones(3, 3, requires_grad=False)
+            t2 = torch.zeros(3, 3, requires_grad=False)
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+            # all worker_ids in dst_ranks should be recorded.
+            ctx = dist_autograd._current_context()
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(worker_ids, dst_ranks)
+
+            # worker_ids should be recorded when tensors do require grad
+            t1.requires_grad = True
+            t2.requires_grad = True
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+            # all worker_ids in dst_ranks should be recorded.
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(worker_ids, dst_ranks)
+
+    @dist_init
+    def test_dist_autograd_profiling(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(3, 3, requires_grad=True)
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            ).sum()
+            with torch.autograd.profiler.profile() as p:
+                dist_autograd.backward(context_id, [loss])
+
+        function_events = p.function_events
+
+        def get_event(partial_key):
+            return next(event for event in function_events if partial_key in event.name)
+
+        send_event = get_event("SendRpcBackward")
+        recv_event = get_event("RecvRpcBackward")
+        backward_event = get_event("torch::distributed::autograd::backward")
+        # There should be at least 1 send and recv_events each, corresponding to send/recv functions executed.
+        self.assertEqual(send_event.count, 1)
+        self.assertEqual(recv_event.count, 1)
+        # The CPU total for backward event should be great than send and recv, since
+        # applying those functions in the backwards pass is a subset of the entire backward pass.
+        self.assertGreater(backward_event.cpu_time_total, send_event.cpu_time_total)
+        self.assertGreater(backward_event.cpu_time_total, recv_event.cpu_time_total)
+
+    @dist_init
+    def test_error_in_context(self):
+        with dist_autograd.context():
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(6, 6, requires_grad=True)
+
+            with self.assertRaises(RuntimeError):
+                # This should throw an error since matrix sizes don't match.
+                rpc.rpc_sync(
+                    worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
+                )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor(self):
+        self._backward_no_grad_on_tensor(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple(self):
+        self._backward_simple(
+            self._next_rank(),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_self(self):
+        self._backward_simple(
+            self.rank,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_rref(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_rref_multi(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                torch.rand((3, 3), requires_grad=True),
+                torch.rand((3, 3), requires_grad=True),
+                None,
+                False,
+            )
+
+    @dist_init
+    def test_backward_rref_nested(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_trainer_ps(self):
+        self._test_trainer_ps(create_tensor, _run_trainer, False)
+
+    @dist_init
+    def test_trainer_ps_torchscript_functions(self):
+        # TODO, need more investigation
+        # there is rref leak when shutting down, suspect it is because
+        # ref as arg is passed to pybind boundary, and the ref is not garbage
+        # collected by python when calling shutdown()
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = True
+
+        self._test_trainer_ps(
+            create_torchscript_tensor, _run_trainer_torchscript, False
+        )
+
+    @dist_init
+    def test_backward_multiple_round_trips(self):
+        self._backward_multiple_round_trips(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_different_tensor_dims(self):
+        local_grads = None
+        t1 = torch.rand((4, 6), requires_grad=True)
+        t2 = torch.rand((6, 5))
+        t3 = torch.rand((5, 7), requires_grad=True)
+        t4 = torch.rand((7, 9))
+
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4))
+                loss = val.sum()
+
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_backward_unused_tensors(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        t3 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
+                val = self._exec_func(
+                    exec_mode,
+                    torch.matmul,
+                    torch.narrow(s, 0, 0, 1),
+                    torch.narrow(s, 0, 2, 1),
+                )
+
+                loss = val.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t3
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_backward_multiple_output_tensors(self):
+        local_grads = None
+        t = torch.rand((10, 2), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
+                t1 = tensor_list[0]
+                t2 = tensor_list[2]
+                t3 = tensor_list[4]
+
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3))
+
+                loss = val.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t
+                )
+                local_grads = ret if ret else local_grads
+
+    def _run_test_backward_unused_send_function_in_thread(self):
+        """Thread body: issue an RPC whose result is unused, then run backward.
+
+        Per the comments below, the backward call is expected never to return
+        because the send function of the unused RPC result is never reached.
+        """
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            # We don't use the result of an RPC function, as a result the
+            # backward pass would hang in the "FAST" mode.
+            rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+
+            # The loss is computed purely locally and ignores the RPC result above.
+            val = torch.mul(t1, t2)
+
+            # Run backward, this would hang forever.
+            dist_autograd.backward(context_id, [val.sum()])
+
+    @dist_init
+    def test_backward_unused_send_function(self):
+        # Run the test in a thread which would never finish.
+        t = threading.Thread(
+            target=self._run_test_backward_unused_send_function_in_thread
+        )
+        t.daemon = True
+        t.start()
+        t.join(10)  # Wait for 10s.
+
+        # Verify thread is still alive (indicating backward hasn't completed yet).
+        self.assertTrue(t.is_alive())
+
+    @dist_init
+    def test_backward_autograd_engine_error(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            # Perform some ops before error simulation.
+            tmp = (t1 + t2) * (t1 + t2)
+            t3 = SimulateBackwardError.apply(tmp)
+
+            # Run multiple round trips across different nodes and verify the
+            # original node receives an error thrown on a node deep in the chain.
+            val = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t2, t3))
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.mul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.matmul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.div, args=(val, t2)
+            )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Error on Node [0-9]+: Simulate error on backward pass"
+            ):
+                # Run backwards, and validate we receive an error.
+                dist_autograd.backward(context_id, [val.sum()])
+
+    @dist_init(clean_shutdown=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_backward_node_failure(self):
+        """Backward must raise a shutdown error once the odd ranks have exited.
+
+        After a barrier, odd ranks fall through and leave the test. Even ranks
+        wait for every odd rank to fail and then expect
+        ``dist_autograd.backward`` to raise a ``RuntimeError`` matching
+        ``self.get_shutdown_error_regex()``.
+        """
+        rpc._set_rpc_timeout(5)  # 5 seconds
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+
+            # Wait for all RPCs to be done.
+            dist.barrier()
+
+            # Even ranks stay alive and wait for the odd ranks, which take the
+            # `else` branch below and exit the test.
+            if self.rank % 2 == 0:
+                shutdown_error_regex = self.get_shutdown_error_regex()
+                # Wait for all other nodes to die.
+                for rank in range(self.world_size):
+                    if rank % 2 != 0:
+                        wait_until_node_failure(rank, shutdown_error_regex)
+
+                # Shutdown sequence is not very well defined and as a result
+                # we might see any error given by get_shutdown_error_regex()
+                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
+                    # Run backwards, and validate we receive an error since all
+                    # other nodes are dead.
+                    dist_autograd.backward(context_id, [res.sum()])
+            else:
+                # Exit all other nodes.
+                pass
+
+    @dist_init
+    def test_backward_without_context(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+
+        context_id = 100  # dummy context_id
+        with self.assertRaisesRegex(
+            RuntimeError,
+            f"Could not find autograd context with id: {context_id}",
+        ):
+            res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+            dist_autograd.backward(context_id, [res.sum()])
+
+    @dist_init
+    def test_backward_without_rpc(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_backward_invalid_args(self):
+        with dist_autograd.context() as context_id:
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(context_id, None)
+
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(None, None)
+
+            with self.assertRaisesRegex(
+                RuntimeError, "No tensors provided for gradient computation"
+            ):
+                dist_autograd.backward(context_id, [])
+
+            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
+                t = torch.rand(3, 3)
+                dist_autograd.backward(context_id, [t])
+
+            with self.assertRaisesRegex(
+                RuntimeError, "is not a scalar, all roots need to be scalar"
+            ):
+                t = torch.rand(3, 3, requires_grad=True)
+                dist_autograd.backward(context_id, [t])
+
+            with self.assertRaisesRegex(
+                RuntimeError, "does not have a valid gradient function"
+            ):
+                t = torch.rand(1, requires_grad=True)
+                dist_autograd.backward(context_id, [t])
+
+    @dist_init
+    def test_backward_multiple_roots(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+            with dist_autograd.context() as context_id:
+                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
+                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
+                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
+                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()
+
+                local_grads = self._verify_backwards(
+                    exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
+                )
+
+    @dist_init
+    def test_backward_different_dtypes(self):
+        self._backward_different_dtypes(
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf(self):
+        self._backward_simple_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_script_call(self):
+        self._backward_simple_script_call(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @staticmethod
+    def _complex_python_udf(t1, t2):
+        t3 = torch.nn.functional.linear(t1, t2)
+        t4 = torch.nn.functional.linear(t2, t3)
+        t5 = torch.nn.functional.linear(t3, t4)
+        return torch.linalg.multi_dot([t1, t2, t3, t4, t5])
+
+    @dist_init
+    def test_backward_complex_python_udf(self):
+        # Run the same code locally and with dist autograd and verify gradients
+        # are same.
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(
+                    exec_mode, DistAutogradTest._complex_python_udf, t1, t2
+                )
+                loss = ret.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    @staticmethod
+    def _python_udf_with_backward_error(t1, t2):
+        t3 = t1 + t2
+        t4 = SimulateBackwardError.apply(t3)
+        return torch.linalg.multi_dot([t1, t2, t3, t4])
+
+    @staticmethod
+    def _nested_rpc_call_backward_error(t1, t2, dst):
+        t1 = t1 * t2
+        t2 = t1 + t2
+        res = rpc.rpc_sync(
+            worker_name(dst),
+            DistAutogradTest._python_udf_with_backward_error,
+            args=(t1, t2),
+        )
+        return torch.linalg.multi_dot([t1, t2, res])
+
+    @dist_init
+    def test_backward_python_udf_error(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._nested_rpc_call_backward_error,
+                args=(t1, t2, self._next_rank()),
+            )
+            with self.assertRaisesRegex(
+                RuntimeError, "Simulate error on backward pass"
+            ):
+                dist_autograd.backward(context_id, [loss.sum()])
+
+    _backward_done = False
+
+    @dist_init(clean_shutdown=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_backward_node_failure_python_udf(self):
+        """Backward on rank 0 must raise once rank 2 has left the test.
+
+        Every rank issues a nested Python UDF RPC and meets at a barrier. Rank 2
+        then returns immediately; rank 0 waits for it to fail, expects backward
+        to raise a shutdown error, and publishes a completion key in the store.
+        All remaining ranks wait on that key for up to 10 seconds.
+        """
+        # Set a short timeout to quickly time out failed RPCs.
+        rpc._set_rpc_timeout(5)  # 5 seconds
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            dst = self._next_rank()
+            res = rpc.rpc_sync(
+                worker_name(dst),
+                my_py_nested_call,
+                args=(t1, t2, dst, self.world_size, 1),
+            )
+
+            # Make sure every rank has finished its forward RPC before any exits.
+            dist.barrier()
+
+            # Kill rank 2 (last hop of nested rpc) and verify rank 0 receives an error.
+            if self.rank == 2:
+                return
+
+            # The store is used for the final handshake instead of RPC.
+            store = dist.distributed_c10d._get_default_store()
+            if self.rank == 0:
+                # Wait for rank 2 to die.
+                shutdown_error_regex = self.get_shutdown_error_regex()
+                wait_until_node_failure(2, shutdown_error_regex)
+                # Shutdown sequence is not very well defined and as a result
+                # we might see any error given by get_shutdown_error_regex().
+                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
+                    # Run backwards, and validate we receive an error since rank 2 is dead.
+                    dist_autograd.backward(context_id, [res.sum()])
+
+                # Mark rank 0 is done in the store, since the RPC framework on
+                # some nodes might be broken at this point.
+                store.set("test_backward_node_failure_python_udf_rank0_done", "True")
+            else:
+                # Wait for backward to finish on rank 0.
+                store.wait(
+                    ["test_backward_node_failure_python_udf_rank0_done"],
+                    timedelta(seconds=10),
+                )
+
+    @staticmethod
+    def _nested_python_udf(t1, t2, dst):
+        t3 = t1 * t2
+        t4 = t1 + t2
+        res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
+        return t1 * t2 * t3 * t4 * res
+
+    @dist_init
+    def test_backwards_nested_python_udf(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    _test_clean_context_backward_context_id = None
+
+    class MyBackwardFunc(Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        @once_differentiable
+        def backward(ctx, input):
+            assert DistAutogradTest._test_clean_context_backward_context_id is not None
+
+            # Release the context to simulate error (use barrier before releasing
+            # context to ensure all nodes execute the backward function).
+            dist.barrier()
+            dist_autograd._release_context(
+                DistAutogradTest._test_clean_context_backward_context_id
+            )
+
+            # Verify all contexts are cleaned up.
+            assert _all_contexts_cleaned_up()
+
+            return input
+
+    @dist_init
+    def test_clean_context_during_backward(self):
+        """
+        This test simulates the situation where the 'backward' call might throw
+        an exception locally which would lead to the autograd context being
+        cleaned up if we're using the context manager. As a result, the autograd
+        context might be cleaned up while some threads are still using the
+        autograd context.
+
+        It is fine for the 'backward' call to throw an exception in this test,
+        but the process should not crash.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # Create the context by hand (no context manager) and publish its id
+        # on the class so MyBackwardFunc.backward can read it.
+        context = dist_autograd._new_context()
+        context_id = context._context_id()
+        DistAutogradTest._test_clean_context_backward_context_id = context_id
+
+        # Send the context id to all nodes.
+        for i in range(self.world_size):
+            if i != self.rank:
+                # Distance from this rank to rank i, wrapping around world_size.
+                rank_distance = (i - self.rank + self.world_size) % self.world_size
+                rpc.rpc_sync(
+                    worker_name(i),
+                    _set_rpc_done,
+                    args=(context_id, rank_distance),
+                )
+
+        dist.barrier()
+
+        # Verify all context ids have been received.
+        self.assertEqual(self.world_size - 1, len(known_context_ids))
+
+        # Chain 100 RPCs, each feeding the previous result back in.
+        t1 = torch.rand((3, 3), requires_grad=True)
+        for _ in range(100):
+            dst = self._next_rank()
+            t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))
+
+        # Call MyBackwardFunc as the first op of the backward pass to
+        # ensure we release the context early in the backward pass.
+        t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
+        self.assertEqual(100, len(context._send_functions()))
+
+        # From here on context_id is rebound to a dummy value, and the backward
+        # call below is issued with that dummy id.
+        # NOTE(review): because the id passed is the dummy one, the expected
+        # error may be raised by the context lookup itself rather than by
+        # MyBackwardFunc releasing the real context -- confirm this is intended.
+        context_id = 100  # dummy context_id
+        with self.assertRaisesRegex(
+            RuntimeError,
+            f"Could not find autograd context with id: {context_id}",
+        ):
+            dist_autograd.backward(context_id, [t1.sum()])
+
+        # HACK: Killing workers since otherwise the autograd engine gets stuck on
+        # other nodes. The proper fix would be addressing:
+        # https://github.com/pytorch/pytorch/issues/27643, which would inform
+        # other nodes about the failure.
+        # The autograd engine gets stuck on other nodes since they're waiting to
+        # receive gradients from the node that received an error (and as a
+        # result it didn't execute the rest of the graph).
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+        sys.exit(0)
+
+    @classmethod
+    def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
+        embedding = embedding_rref.local_value()
+        return embedding(input, offsets, per_sample_weights)
+
+    @classmethod
+    def _get_grad(cls, embedding_rref, context_id):
+        embedding = embedding_rref.local_value()
+        grad_map = dist_autograd.get_gradients(context_id)
+        return grad_map[embedding.weight]
+
+    @classmethod
+    def _mixed_requires_grad_operaton(cls, t1, t2):
+        if t2.requires_grad:
+            return t1 - t2
+        else:
+            return t1 * t2
+
+    @dist_init
+    def test_mixed_requires_grad(self):
+        self._mixed_requires_grad(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=False),
+            False,
+        )
+
+    class TestDebugInfoFunc(Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        @once_differentiable
+        def backward(ctx, input):
+            debug_info = dist_autograd._get_debug_info()
+            assert debug_info is not None
+            backward_passes = int(debug_info["num_current_backward_passes"])
+
+            # Hard to validate exact numbers because of the distributed nature.
+            # We can't use a barrier() here since that would block the single
+            # CPU thread available for autograd and can cause deadlocks.
+            assert backward_passes >= 1 and backward_passes <= 4
+            return input
+
+    @dist_init
+    def test_debug_info(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            i = 0
+            res = {}
+            res[i] = t1
+            for rank in range(self.world_size):
+                if rank != self.rank:
+                    res[i + 1] = rpc.rpc_sync(
+                        worker_name(rank), torch.add, args=(res[i], t2)
+                    )
+                    i += 1
+
+            # Call custom function in middle of backward pass to ensure all
+            # nodes are still waiting on a backward().
+            res[i + 1] = DistAutogradTest.TestDebugInfoFunc.apply(res[i])
+            i += 1
+
+            for rank in range(self.world_size):
+                if rank != self.rank:
+                    res[i + 1] = rpc.rpc_sync(
+                        worker_name(rank), torch.add, args=(res[i], t2)
+                    )
+                    i += 1
+
+            dist_autograd.backward(context_id, [res[i].sum()])
+
+            debug_info = dist_autograd._get_debug_info()
+            num_autograd_context = int(debug_info["num_autograd_contexts"])
+            # Need at least one context and not more than 4.
+            self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4)
+
+        for rd in range(self.world_size - 1):
+            rpc.rpc_sync(
+                worker_name((self.rank + rd + 1) % self.world_size),
+                _set_rpc_done,
+                args=(context_id, rd + 1),
+            )
+
+        dist.barrier()
+
+        # Validate information
+        debug_info = dist_autograd._get_debug_info()
+        assert debug_info is not None
+        self.assertEqual(0, int(debug_info["num_current_backward_passes"]))
+        # only have `num_current_backward_passes` and `num_autograd contexts`
+        self.assertTrue(len(debug_info) == 2)
+
+        self.assertTrue(_all_contexts_cleaned_up())
+
+        # All contexts should be cleaned up.
+        debug_info = dist_autograd._get_debug_info()
+        self.assertEqual(0, int(debug_info["num_autograd_contexts"]))
+
+    @staticmethod
+    def _workload_thread():
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            t3 = rpc.rpc_sync("worker0", torch.add, args=(t1, t2))
+            t4 = rpc.rpc_sync("worker0", torch.mul, args=(t2, t3))
+            t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
+            t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))
+
+            dist_autograd.backward(context_id, [t6.sum()])
+
+    @dist_init
+    def test_async_dist_autograd(self):
+        """
+        This test ensures async processing for distributed autograd works
+        appropriately. This is achieved by spawning multiple threads and
+        hammering a single node with a lot of backward() calls.
+        """
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        if self.rank != 0:
+            # All other ranks schedule work on rank 0.
+            threads = []
+            for _ in range(20):
+                t = threading.Thread(target=DistAutogradTest._workload_thread)
+                t.start()
+                threads.append(t)
+
+            for thread in threads:
+                thread.join()
+
+        dist.barrier()
+
+    @dist_init
+    def test_backward_accumulate_grads(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            t3 = torch.matmul(t1, t2)
+            # Run backward twice.
+            torch.autograd.backward([t3.sum()], retain_graph=True)
+            torch.autograd.backward([t3.sum()])
+
+            t3 = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
+            )
+            # Run backward twice.
+            dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
+            dist_autograd.backward(context_id, [t3.sum()])
+
+            # Verify the gradients are same for local and remote execution.
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    @staticmethod
+    def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
+        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+
+    @dist_init
+    def test_nested_backward_accumulate_grads(self):
+        self._nested_backward_accumulate_grads(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_multiple_backward(self):
+        self._multiple_backward(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init(clean_shutdown=False)
+    def test_multiple_backward_with_errors(self):
+        """Run backward repeatedly on one context: first expecting errors, then success.
+
+        For the first 50 iterations ``dist_autograd.backward`` is expected to
+        raise a RuntimeError matching "Simulate error on backward pass".
+        Iteration 50 clears ``SimulateBackwardError._simulate_error`` between
+        two barriers; the remaining iterations call backward without
+        expecting an error. The flag is set back to True in ``finally``.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                f"worker{self._next_rank()}",
+                DistAutogradTest._python_udf_with_backward_error,
+                args=(t1, t2),
+            ).sum()
+
+            try:
+                # Run backward in a loop multiple times.
+                for i in range(100):
+                    if i < 50:
+                        # Error phase: every backward call must fail.
+                        with self.assertRaisesRegex(
+                            RuntimeError, "Simulate error on backward pass"
+                        ):
+                            dist_autograd.backward(
+                                context_id, [loss], retain_graph=True
+                            )
+                    elif i > 50:
+                        # Recovered from error.
+                        dist_autograd.backward(context_id, [loss], retain_graph=True)
+                    else:
+                        # i == 50: switch the error simulation off, with a
+                        # barrier on either side of the flag change.
+                        dist.barrier()
+                        SimulateBackwardError._simulate_error = False
+                        dist.barrier()
+            finally:
+                # Sync before resetting flag.
+                dist.barrier()
+
+                # Reset the flag.
+                SimulateBackwardError._simulate_error = True
+
+    @dist_init
+    def test_backward_verify_hooks(self):
+        t1 = torch.ones((3, 3), requires_grad=True)
+        # Double the gradient.
+        t1.register_hook(lambda grad: grad * 2)
+        t2 = torch.ones((3, 3), requires_grad=True)
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                loss = ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_no_grad_copy(self):
+        """
+        Similar to test in test_autograd.py.
+        """
+
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp1, inp2):
+                return inp1 + inp2
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad.data_ptr()
+                return grad, grad
+
+        class MyFuncSingleGrad(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFuncSingleGrad.static_grad_ptr = grad.data_ptr()
+                return grad
+
+        class NonContGradFunc(Function):
+            @staticmethod
+            def forward(ctx, inp1):
+                ctx.size = inp1.size()
+                return torch.tensor([1.0])
+
+            @staticmethod
+            def backward(ctx, grad):
+                return torch.ones(1).expand(ctx.size)
+
+        a = torch.randn(5, 6, requires_grad=True)
+        b = torch.randn(5, 6, requires_grad=True)
+        # non-contiguous grad should be copied
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(
+                context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))]
+            )
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
+            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
+
+        # test case that should trigger no copy for a
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(context_id, [MyFuncSingleGrad.apply(a)[1][0]])
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFuncSingleGrad.static_grad_ptr
+            p_a = grads[a].data_ptr()
+            # Verify there was no clone.
+            self.assertTrue(p_a == p_g)
+
+        # Test case that should trigger copy for both of a,b. This is
+        # different in the distributed autograd case since we hold
+        # a reference to all grads in a vector until all accumulation is done.
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(context_id, [MyFunc.apply(a, b)[1][0]])
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a].data_ptr()
+            p_b = grads[b].data_ptr()
+            # check a,b uses different grad buffer
+            self.assertFalse(p_a == p_b)
+            # both should be copied.
+            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
+            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
+
+    @dist_init
+    def test_no_grad_copy_sparse(self):
+        """Sparse-gradient variant of the grad-copy check.
+
+        Uses ``F.embedding_bag(..., sparse=True)`` and compares the data
+        pointer of the gradient's ``_values()`` seen inside a custom backward
+        with the one later returned by ``dist_autograd.get_gradients``.
+        """
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad._values().data_ptr()
+                return grad
+
+        class NonContGradFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp1, inp2):
+                return inp1 + inp2
+
+            @staticmethod
+            def backward(ctx, grad):
+                # Create a sparse tensor with non-contiguous indices and values
+                # and return as grad.
+                v = torch.rand(1, 3)
+                i = torch.ones(1, 1, dtype=torch.long)
+                nv = v.expand(8, 3)
+                ni = i.expand(1, 8)
+                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
+                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
+                return ngrad, ngrad
+
+        a = torch.randn(10, 3, requires_grad=True)
+        b = torch.randn(10, 3, requires_grad=True)
+        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        import torch.nn.functional as F
+
+        # test case that should trigger no copy for a.
+        with dist_autograd.context() as context_id:
+            emb_matrix = MyFunc.apply(a)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            # check a uses the same buffer
+            self.assertTrue(p_a == p_g)
+
+            # Run backwards multiple times.
+            for _ in range(10):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+        # non-contiguous indices and value, we should trigger a copy.
+        with dist_autograd.context() as context_id:
+            emb_matrix = NonContGradFunc.apply(a, b)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = NonContGradFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            p_b = grads[b]._values().data_ptr()
+            # check a,b uses different grad buffer
+            self.assertFalse(p_a == p_b)
+            # Verify we cloned both grads.
+            self.assertFalse(p_a == p_g)
+            self.assertFalse(p_b == p_g)
+
+            # Run backwards multiple times to verify accumulation.
+            for _ in range(10):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+    @dist_init
+    def test_grad_copy_sparse_indices_extra_ref(self):
+        """Check sparse grad buffer reuse when backward keeps extra references.
+
+        The custom backward stores ``grad._indices()`` and ``grad._values()``
+        on the class; the test then asserts that the values pointer seen in
+        backward equals the one returned by ``dist_autograd.get_gradients``.
+        """
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+            static_grad_indices_ref = None
+            static_grad_values_ref = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad._values().data_ptr()
+                # indices() and values() return views, so holding onto
+                # references of them would not increment refcount of indices
+                # and values inside the sparse tensor.
+                MyFunc.static_grad_indices_ref = grad._indices()
+                MyFunc.static_grad_values_ref = grad._values()
+                return grad
+
+        a = torch.randn(10, 3, requires_grad=True)
+        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        import torch.nn.functional as F
+
+        with dist_autograd.context() as context_id:
+            emb_matrix = MyFunc.apply(a)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            # Confirm backward actually ran and stored both references.
+            self.assertIsNotNone(MyFunc.static_grad_indices_ref)
+            self.assertIsNotNone(MyFunc.static_grad_values_ref)
+            # grad would be stolen, since static_grad_indices_ref and
+            # static_grad_values_ref are holding onto views and don't bump the
+            # refcount.
+            self.assertTrue(p_g == p_a)
+
+    @dist_init
+    def test_post_hooks(self):
+        self.hook_called_times = 0
+
+        def post_hook_add_one(output_grads, input_grads):
+            self.hook_called_times += 1
+            return output_grads
+
+        def post_hook_add_two(output_grads, input_grads):
+            self.hook_called_times += 2
+            return output_grads
+
+        t = torch.rand(10, 10, requires_grad=True)
+        a = t + t
+
+        # Register post hooks
+        accumulate_grad_0 = a.grad_fn.next_functions[0][0]
+        accumulate_grad_0.register_hook(post_hook_add_one)
+        accumulate_grad_0.register_hook(post_hook_add_two)
+
+        accumulate_grad_1 = a.grad_fn.next_functions[1][0]
+        accumulate_grad_1.register_hook(post_hook_add_two)
+
+        with dist_autograd.context() as context_id:
+            loss = a.sum()
+            dist_autograd.backward(context_id, [loss])
+            self.assertEqual(5, self.hook_called_times)
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(1, len(grads))
+            self.assertTrue(t in grads)
+
+    @staticmethod
+    def _slow_add(t1, t2):
+        time.sleep(1)
+        t3 = t1 + t2
+        t3.requires_grad = True
+        return t3
+
+    @dist_init
+    def test_thread_local_context_id(self):
+        """Check gradients for a backward pass whose continuation runs on another thread.
+
+        A local reference gradient is computed for ``t1 + t2``; the same sum
+        is then produced remotely by ``_slow_add`` and the owner is asked,
+        via ``_compare_owner_value``, to compare against ``t3.grad``.
+        """
+        t1 = torch.rand((3, 3))
+        t2 = torch.rand((3, 3))
+
+        # Local reference computation; populates t3.grad.
+        t3 = t1 + t2
+        t3.requires_grad = True
+        t3.sum().backward()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))
+
+        with dist_autograd.context() as context_id:
+            loss = rref.to_here().sum()
+            # due to slow add, the continuation of this backward pass will be
+            # invoked by the previous rpc.remote thread which does not have a
+            # valid context_id. So, this can test whether we propagate
+            # thread_local states properly when jumping across threads on the
+            # server side.
+            dist_autograd.backward(context_id, [loss])
+            self.assertTrue(
+                rpc.rpc_sync(
+                    dst, _compare_owner_value, args=(context_id, rref, t3.grad)
+                )
+            )
+
+
+class CudaDistAutogradTest(CommonDistAutogradTest):
+    # Distributed autograd tests that place tensors on "cuda:0"; each test is
+    # skipped when fewer than one GPU is available.
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_simple(self):
+        """Compare local and dist autograd gradients for a simple add of two GPU tensors."""
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        # Local reference gradients, stored in t1.grad / t2.grad.
+        (t1 + t2).sum().backward()
+        with dist_autograd.context() as context_id:
+            t3 = t1 + t2
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_to_cpu_continuation(self):
+        """Verify backward over a graph that alternates CPU and GPU ops and ends on a CPU tensor."""
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True)
+        # Run a few iterations.
+        for _ in range(3):
+            t1.grad = None
+            t2.grad = None
+            # Root is CPU
+            local_grads = None
+            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+                with dist_autograd.context() as context_id:
+                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
+                    t4 = t3.cuda(0) + t1
+                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
+                    t6 = t5.cuda(0) + t4
+                    t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5)
+                    # Autograd graph consists of CPU -> GPU -> CPU execution.
+                    ret = self._verify_backwards(
+                        exec_mode, [t7.sum()], context_id, local_grads, t1, t2
+                    )
+                    local_grads = ret if ret else local_grads
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_to_cpu_continuation_gpu_root(self):
+        """Verify backward over a graph that alternates CPU and GPU ops and ends on a GPU tensor."""
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True)
+        # Run a few iterations.
+        for _ in range(3):
+            t1.grad = None
+            t2.grad = None
+            # Root is GPU (t6 below is the result of a CUDA add).
+            local_grads = None
+            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+                with dist_autograd.context() as context_id:
+                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
+                    t4 = t3.cuda(0) + t1
+                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
+                    t6 = t5.cuda(0) + t4
+                    # Autograd graph is rooted at a GPU tensor (t6) and
+                    # alternates between GPU and CPU ops.
+                    ret = self._verify_backwards(
+                        exec_mode, [t6.sum()], context_id, local_grads, t1, t2
+                    )
+                    local_grads = ret if ret else local_grads
+
+
+class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
+    # Reusing a simplified helper function from DistAutogradTest to ensure
+    # autograd context is successfully cleaned up even when RPCs are failing.
+    def context_cleanup_test_helper(self, rpc_args, func):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # test that in dist autograd, in the case that tensors communicated over RPC do
+        # NOT require grad, we still cleanup the dist autograd contexts created
+        # on other nodes. This is because the autograd context is still
+        # communicated over RPC even if tensor arguments do not require grad, as
+        # it is possible that the response could.
+        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+
+        with dist_autograd.context() as context_id:
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+        # the thread's context id should be cleaned up
+        with self.assertRaises(RuntimeError):
+            dist_autograd._retrieve_context(context_id)
+        # Ensure all peers have finished mutating the
+        # `known_context_ids` set.
+        dist.barrier()
+        # check that all contexts have been cleaned up.
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
+    # no faulty_messages defined so this fails all retryable messages - see
+    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
+    @dist_init
+    def test_context_cleanup_tensor_with_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_verify_backend_options(self):
+        self.assertEqual(
+            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
+        )
+        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
+        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
+
+
+class WrapperModule(nn.Module):
+    def __init__(self, model, device):
+        super().__init__()
+        self.model = model.to(device)
+
+    def forward(self, *args):
+        return self.model(*args)
+
+    def gradients(self, ctx_id):
+        grads = dist_autograd.get_gradients(ctx_id)
+        return [grads[p] for p in self.model.parameters()]
+
+
+class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_backward_pass(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        t1 = torch.rand(10, device=self.rank, requires_grad=True)
+        t2 = torch.rand(10, device=self.rank, requires_grad=True)
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync(dst, torch.add, args=(t1, t2))
+            dist_autograd.backward(context_id, [res.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(torch.ones(10), grads[t1])
+            self.assertEqual(torch.ones(10), grads[t2])
+            self.assertEqual(t1.device, grads[t1].device)
+            self.assertEqual(t2.device, grads[t2].device)
+
+        rpc.shutdown()
+
+    class MyRemoteCompute(torch.nn.Module):
+        def forward(self, input):
+            input = input * 2.0
+            return input
+
+    class MyLocalCompute(torch.nn.Module):
+        def __init__(self, next_stage):
+            super().__init__()
+            self.next_stage = next_stage
+
+        def forward(self, input):
+            return self.next_stage.rpc_sync().forward(input)
+
+    @skip_if_lt_x_gpu(4)
+    def test_dist_autograd_sync_streams(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute)
+        local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute)
+        for _ in range(10):
+            input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
+            # Run local autograd
+            result = input * 2.0
+            r = random.random()
+            loss = result.sum() * r
+            loss.backward()
+
+            # Run distributed autograd
+            with dist_autograd.context() as context_id:
+                result = local_compute(input)
+                loss = result.sum() * r
+                dist_autograd.backward(context_id, [loss])
+
+                # Compare grads.
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertEqual(input.grad, grads[input])
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(4)
+    def test_gradients_synchronizations(self):
+        options = self.rpc_backend_options
+        for peer_rank in range(self.world_size):
+            options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 0:
+            # this is master
+            layers = [nn.Linear(2000, 2000) for _ in range(self.world_size - 1)]
+            local_layers = [l.to(0) for l in layers]
+            remote_layers = [
+                rpc.remote(
+                    worker_name(rank), WrapperModule, args=(layers[rank - 1], rank)
+                )
+                for rank in range(1, self.world_size)
+            ]
+
+            x = torch.randn(5000, 2000).to(0)
+            # local iteration
+            local_model = nn.Sequential(*local_layers)
+            local_model(x).sum().backward()
+
+            # remote iteration
+            with dist_autograd.context() as context_id:
+                for remote_layer in remote_layers:
+                    x = remote_layer.rpc_sync().forward(x)
+
+                dist_autograd.backward(context_id, [x.sum()])
+
+                futs = []
+                for remote_layer in remote_layers:
+                    futs.append(remote_layer.rpc_async().gradients(context_id))
+
+                for i in range(len(futs)):
+                    local_gradients = [p.grad for p in local_layers[i].parameters()]
+                    for g1, g2 in zip(futs[i].wait(), local_gradients, strict=True):
+                        self.assertEqual(g1, g2)
+
+        rpc.shutdown()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d335325f8364241dd14517da5c67c2a6e6a032b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
@@ -0,0 +1,281 @@
+# mypy: allow-untyped-defs
+
+
+import threading
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+from torch import optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.testing._internal.dist_utils import dist_init
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+class MyModule:
+    lock = threading.Lock()
+
+    def __init__(self, requires_grad=True):
+        # cannot directly use torch.manual_seed(0) as all threads share the same
+        # default generator. The race from multiple RPC threads could mess up
+        # the draw order from the default RNG instance, leading to
+        # non-deterministic behavior. Hence, create a dedicated RNG here.
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        self.w = torch.rand((3, 3), requires_grad=requires_grad, generator=g_cpu)
+
+    def forward(self, t1):
+        return torch.mm(self.w, t1)
+
+    def get_w(self):
+        return self.w
+
+
+class FailingOptimizer(optim.Optimizer):
+    def __init__(self, params):
+        super().__init__(params, {})
+
+    def step(self, closure=None):
+        raise ValueError("Error running optimizer.")
+
+
+class OptimizerFailingOnConstructor(optim.Optimizer):
+    def __init__(self, params):
+        super().__init__(params, {})
+        raise ValueError("Error creating optimizer.")
+
+    def step(self, closure=None):
+        raise NotImplementedError
+
+
+def _call_method(method, obj_rref, *args, **kwargs):
+    return method(obj_rref.local_value(), *args, **kwargs)
+
+
+def remote_method(method, obj_rref, *args, **kwargs):
+    """
+    Call rpc.remote on a method in a remote object.
+
+    Args:
+        method: the method (for example, Class.method)
+        obj_rref (RRef): remote reference to the object
+        args: positional arguments to pass to the method
+        kwargs: keyword arguments to pass to the method
+
+    Returns a RRef to the remote method call result.
+    """
+    return rpc.remote(
+        obj_rref.owner(),
+        _call_method,
+        args=[method, obj_rref] + list(args),
+        kwargs=kwargs,
+    )
+
+
+def rpc_async_method(method, obj_rref, *args, **kwargs):
+    """
+    Call rpc.rpc_async on a method in a remote object.
+
+    Args:
+        method: the method (for example, Class.method)
+        obj_rref (RRef): remote reference to the object
+        args: positional arguments to pass to the method
+        kwargs: keyword arguments to pass to the method
+
+    Returns a Future to the method call result.
+    """
+    return rpc.rpc_async(
+        obj_rref.owner(),
+        _call_method,
+        args=[method, obj_rref] + list(args),
+        kwargs=kwargs,
+    )
+
+
+class DistOptimizerTest(RpcAgentTestFixture):
+    """
+    Tests for ``DistributedOptimizer`` driven over RPC.
+
+    Every test places ``MyModule`` instances on the workers named
+    ``worker{(rank + 1) % world_size}`` and ``worker{(rank + 2) % world_size}``
+    and works with RRefs to their weights.
+    """
+
+    @dist_init()
+    def test_dist_optim_exception(self):
+        """An error raised by the wrapped optimizer's ``step`` must surface from ``dist_optim.step``."""
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        # FailingOptimizer.step always raises ValueError("Error running optimizer.")
+        dist_optim = DistributedOptimizer(
+            FailingOptimizer, [remote_param1, remote_param2]
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu = torch.Generator()
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            # Chain the two remote forwards: module2 consumes module1's output.
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1).sum()
+
+            dist_autograd.backward(context_id, [loss])
+            with self.assertRaisesRegex(Exception, "Error running optimizer"):
+                dist_optim.step(context_id)
+
+    @dist_init()
+    def test_dist_optim_exception_on_constructor(self):
+        """An error raised by the wrapped optimizer's constructor must surface from ``DistributedOptimizer(...)``."""
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        # OptimizerFailingOnConstructor.__init__ always raises
+        # ValueError("Error creating optimizer.").
+        with self.assertRaisesRegex(Exception, "Error creating optimizer."):
+            DistributedOptimizer(
+                OptimizerFailingOnConstructor, [remote_param1, remote_param2]
+            )
+
+    def _test_dist_optim_base(self, optim_cls, *args, **kwargs):
+        """
+        Run one optimizer step locally and one through ``DistributedOptimizer``
+        and assert that both leave the weights in the same state.
+
+        Args:
+            optim_cls: optimizer class used for both the local and the
+                distributed run
+            args: positional arguments forwarded to the optimizer
+            kwargs: keyword arguments forwarded to the optimizer
+        """
+        # local version
+        module1 = MyModule()
+        module2 = MyModule()
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim_cls(params, *args, **kwargs)
+
+        # Snapshot the initial weights before the local step modifies them.
+        old_w1 = module1.w.detach().clone()
+        old_w2 = module2.w.detach().clone()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
+
+        dist_optim = DistributedOptimizer(
+            optim_cls, [remote_param1, remote_param2], *args, **kwargs
+        )
+
+        with dist_autograd.context() as context_id:
+            # Re-seed the same generator and draw t1, t2 in the same order so
+            # the distributed pass sees the same inputs as the local pass.
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
+            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()
+
+            # ensure optimizer changed weights
+            self.assertNotEqual(old_w1, new_w1)
+            self.assertNotEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
+
+    @dist_init()
+    def test_dist_optim(self):
+        """Run the local-vs-distributed comparison for a range of optimizers and settings."""
+        self._test_dist_optim_base(optim.Adagrad, lr=0.05)
+        self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
+        self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
+        self._test_dist_optim_base(optim.SGD, lr=0.05)
+        self._test_dist_optim_base(
+            optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
+        )
+        self._test_dist_optim_base(optim.Adadelta, rho=0.95)
+        self._test_dist_optim_base(optim.RMSprop, lr=0.05)
+        self._test_dist_optim_base(optim.Adamax, lr=0.05)
+        self._test_dist_optim_base(optim.Rprop, lr=0.05)
+
+    def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
+        """
+        Same comparison as ``_test_dist_optim_base``, but the second module is
+        built with ``requires_grad=False``; its weight is asserted to stay
+        unchanged while the first module's weight must change.
+
+        Args:
+            optim_cls: optimizer class used for both the local and the
+                distributed run
+            args: positional arguments forwarded to the optimizer
+            kwargs: keyword arguments forwarded to the optimizer
+        """
+        # local version
+        module1 = MyModule()
+        module2 = MyModule(requires_grad=False)
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim_cls(params, *args, **kwargs)
+
+        # Snapshot the initial weights before the local step modifies them.
+        old_w1 = module1.w.detach().clone()
+        old_w2 = module2.w.detach().clone()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        # args=(False,) is MyModule's requires_grad argument.
+        remote_module2 = rpc.remote(owner2, MyModule, args=(False,))
+        # This test uses the RRef helper proxies (.remote() / .rpc_async())
+        # rather than the remote_method / rpc_async_method functions.
+        remote_param1 = remote_module1.remote().get_w()
+        remote_param2 = remote_module2.remote().get_w()
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
+
+        dist_optim = DistributedOptimizer(
+            optim_cls, [remote_param1, remote_param2], *args, **kwargs
+        )
+
+        with dist_autograd.context() as context_id:
+            # Re-seed the same generator and draw t1, t2 in the same order so
+            # the distributed pass sees the same inputs as the local pass.
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = remote_module1.rpc_async().forward(t2)
+            output2 = remote_module2.rpc_async().forward(output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = remote_module1.rpc_async().get_w().wait()
+            new_w2 = remote_module2.rpc_async().get_w().wait()
+
+            # ensure optimizer changed weights for w1
+            self.assertNotEqual(old_w1, new_w1)
+
+            # ensure optimizer not changed weights for w2
+            self.assertEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
+
+    @dist_init()
+    def test_dist_optim_none_grads(self):
+        """Run the frozen-second-weight comparison for several optimizers."""
+        self._test_dist_optim_none_grads(optim.SGD, lr=0.05)
+        self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05)
+        self._test_dist_optim_none_grads(optim.Rprop, lr=0.05)
+        self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..168bc84542ce854d5e8259f30d292d3b59ee0908
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d26552ffdbb8c1ee77fa6cea4bc8d79a4ae24fb4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9220d1c6499b9baa55dfda738e74acf9ec9191bd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad0b7fbe2207f8533da1eba8c23cda513f2bcf25
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
@@ -0,0 +1,140 @@
+# mypy: allow-untyped-defs
+
+# If you need to modify this file to make this test pass, please also apply same edits accordingly to
+# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py
+# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server
+
+import threading
+from datetime import datetime
+from time import perf_counter
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+from torch import optim
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# Number of rows in each generated input/label batch.
+batch_size = 20
+# Input width of the parameter server's nn.Linear model and of each generated input.
+in_features = 100
+# Output width of the parameter server's nn.Linear model and of each generated label.
+out_features = 30
+# Number of batches Trainer.get_next_batch yields.
+num_batches = 4
+
+
+def timed_log(text):
+    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")
+
+
+class BatchUpdateParameterServer:
+    def __init__(self, batch_update_size):
+        self.model = nn.Linear(in_features, out_features)
+        self.lock = threading.Lock()
+        self.future_model = torch.futures.Future()
+        self.batch_update_size = batch_update_size
+        self.curr_update_size = 0
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
+        for p in self.model.parameters():
+            p.grad = torch.zeros_like(p)
+
+    def get_model(self):
+        return self.model
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def update_and_fetch_model(ps_rref, grads):
+        self = ps_rref.local_value()
+        for p, g in zip(self.model.parameters(), grads, strict=True):
+            if p.grad is None:
+                p.grad = g
+            else:
+                p.grad += g
+        with self.lock:
+            timed_log(
+                f"PS got {self.curr_update_size}/{self.batch_update_size} updates"
+            )
+            self.curr_update_size += 1
+            fut = self.future_model
+
+            if self.curr_update_size >= self.batch_update_size:
+                for p in self.model.parameters():
+                    p.grad /= self.batch_update_size
+                self.curr_update_size = 0
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                fut.set_result(self.model)
+                timed_log("PS updated model")
+                self.future_model = torch.futures.Future()
+
+        return fut
+
+
+class Trainer:
+    def __init__(self, ps_rref):
+        self.ps_rref = ps_rref
+        self.loss_fn = nn.L1Loss()
+
+    def get_next_batch(self):
+        for _ in range(num_batches):
+            inputs = torch.randn(batch_size, in_features)
+            labels = torch.zeros(batch_size, out_features)
+            yield inputs, labels
+
+    def train(self):
+        name = rpc.get_worker_info().name
+        m = self.ps_rref.rpc_sync().get_model()
+        for inputs, labels in self.get_next_batch():
+            timed_log(f"{name} processing one batch")
+            self.loss_fn(m(inputs), labels).backward()
+            timed_log(f"{name} reporting grads")
+            m = rpc.rpc_sync(
+                self.ps_rref.owner(),
+                BatchUpdateParameterServer.update_and_fetch_model,
+                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
+            )
+            timed_log(f"{name} got updated model")
+
+
+def run_trainer(ps_rref):
+    trainer = Trainer(ps_rref)
+    trainer.train()
+
+
+def run_ps(trainers):
+    timed_log("Start training")
+    start = perf_counter()
+    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
+    futs = [
+        rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers
+    ]
+
+    torch.futures.wait_all(futs)
+    stop = perf_counter()
+    timed_log("Finish training")
+    timed_log(f"Time spent training: {stop - start}s")
+
+
+class ParameterServerTest(RpcAgentTestFixture):
+    @dist_init(setup_rpc=False)
+    def test_batch_updating_parameter_server(self):
+        if self.rank != 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        else:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            run_ps([f"{worker_name(r)}" for r in range(1, self.world_size)])
+
+        rpc.shutdown()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..57008aed17dba34aacbc3b8a7a5b62c6dcbb5526
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py
@@ -0,0 +1,265 @@
+# mypy: allow-untyped-defs
+
+# If you need to modify this file to make this test pass, please also apply same edits accordingly to
+# https://github.com/pytorch/examples/blob/master/distributed/rpc/rl/main.py
+# and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html
+
+import numpy as np
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributed.rpc import remote, rpc_async, rpc_sync, RRef
+from torch.distributions import Categorical
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# Step budget: the test passes int(TOTAL_EPISODE_STEP / (world_size - 1)) as n_steps.
+TOTAL_EPISODE_STEP = 5000
+# Discount factor used in Agent.finish_episode (R = r + GAMMA * R).
+GAMMA = 0.1
+# Seed every Observer passes to its environment's seed().
+SEED = 543
+
+
+def _call_method(method, rref, *args, **kwargs):
+    r"""
+    a helper function to call a method on the given RRef
+    """
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_method(method, rref, *args, **kwargs):
+    r"""
+    a helper function to run method on the owner of rref and fetch back the
+    result using RPC
+    """
+    args = [method, rref] + list(args)
+    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
+
+
+class Policy(nn.Module):
+    r"""
+    Borrowing the ``Policy`` class from the Reinforcement Learning example.
+    Copying the code to make these two examples independent.
+    See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.affine1 = nn.Linear(4, 128)
+        self.dropout = nn.Dropout(p=0.6)
+        self.affine2 = nn.Linear(128, 2)
+
+        self.saved_log_probs = []
+        self.rewards = []
+
+    def forward(self, x):
+        x = self.affine1(x)
+        x = self.dropout(x)
+        x = F.relu(x)
+        action_scores = self.affine2(x)
+        return F.softmax(action_scores, dim=1)
+
+
+class DummyEnv:
+    r"""
+    A dummy environment that implements the required subset of the OpenAI gym
+    interface. It exists only to avoid a dependency on gym for running the
+    tests in this file. It is designed to run for a set max number of iterations,
+    returning random states and rewards at each step.
+    """
+
+    def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
+        self.state_dim = state_dim
+        self.num_iters = num_iters
+        self.iter = 0
+        self.reward_threshold = reward_threshold
+
+    def seed(self, manual_seed):
+        torch.manual_seed(manual_seed)
+
+    def reset(self):
+        self.iter = 0
+        return torch.randn(self.state_dim)
+
+    def step(self, action):
+        self.iter += 1
+        state = torch.randn(self.state_dim)
+        reward = torch.rand(1).item() * self.reward_threshold
+        done = self.iter >= self.num_iters
+        info = {}
+        return state, reward, done, info
+
+
+class Observer:
+    r"""
+    An observer has exclusive access to its own environment. Each observer
+    captures the state from its environment, and send the state to the agent to
+    select an action. Then, the observer applies the action to its environment
+    and reports the reward to the agent.
+    """
+
+    def __init__(self) -> None:
+        self.id = rpc.get_worker_info().id
+        self.env = DummyEnv()
+        self.env.seed(SEED)
+
+    def run_episode(self, agent_rref, n_steps):
+        r"""
+        Run one episode of n_steps.
+        Arguments:
+            agent_rref (RRef): an RRef referencing the agent object.
+            n_steps (int): number of steps in this episode
+        """
+        state, _ep_reward = self.env.reset(), 0
+        for _ in range(n_steps):
+            # send the state to the agent to get an action
+            action = _remote_method(Agent.select_action, agent_rref, self.id, state)
+
+            # apply the action to the environment, and get the reward
+            state, reward, done, _ = self.env.step(action)
+
+            # report the reward to the agent for training purpose
+            _remote_method(Agent.report_reward, agent_rref, self.id, reward)
+
+            if done:
+                break
+
+
+class Agent:
+    def __init__(self, world_size):
+        self.ob_rrefs = []
+        self.agent_rref = RRef(self)
+        self.rewards = {}
+        self.saved_log_probs = {}
+        self.policy = Policy()
+        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
+        self.eps = np.finfo(np.float32).eps.item()
+        self.running_reward = 0
+        self.reward_threshold = DummyEnv().reward_threshold
+        for ob_rank in range(1, world_size):
+            ob_info = rpc.get_worker_info(worker_name(ob_rank))
+            self.ob_rrefs.append(remote(ob_info, Observer))
+            self.rewards[ob_info.id] = []
+            self.saved_log_probs[ob_info.id] = []
+
+    def select_action(self, ob_id, state):
+        r"""
+        This function is mostly borrowed from the Reinforcement Learning example.
+        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+        The main difference is that instead of keeping all probs in one list,
+        the agent keeps probs in a dictionary, one key per observer.
+
+        NB: no need to enforce thread-safety here as GIL will serialize
+        executions.
+        """
+        probs = self.policy(state.unsqueeze(0))
+        m = Categorical(probs)
+        action = m.sample()
+        self.saved_log_probs[ob_id].append(m.log_prob(action))
+        return action.item()
+
+    def report_reward(self, ob_id, reward):
+        r"""
+        Observers call this function to report rewards.
+        """
+        self.rewards[ob_id].append(reward)
+
+    def run_episode(self, n_steps=0):
+        r"""
+        Run one episode. The agent will tell each observer to run n_steps.
+        """
+        # make async RPC to kick off an episode on all observers
+        futs = [
+            rpc_async(
+                ob_rref.owner(),
+                _call_method,
+                args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps),
+            )
+            for ob_rref in self.ob_rrefs
+        ]
+
+        # wait until all observers have finished this episode
+        for fut in futs:
+            fut.wait()
+
+    def finish_episode(self):
+        r"""
+        This function is mostly borrowed from the Reinforcement Learning example.
+        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+        The main difference is that it joins all probs and rewards from
+        different observers into one list, and uses the minimum observer rewards
+        as the reward of the current episode.
+        """
+
+        # joins probs and rewards from different observers into lists
+        R, probs, rewards = 0, [], []
+        for ob_id in self.rewards:
+            probs.extend(self.saved_log_probs[ob_id])
+            rewards.extend(self.rewards[ob_id])
+
+        # use the minimum observer reward to calculate the running reward
+        min_reward = min(sum(self.rewards[ob_id]) for ob_id in self.rewards)
+        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward
+
+        # clear saved probs and rewards
+        for ob_id in self.rewards:
+            self.rewards[ob_id] = []
+            self.saved_log_probs[ob_id] = []
+
+        policy_loss, returns = [], []
+        for r in rewards[::-1]:
+            R = r + GAMMA * R
+            returns.insert(0, R)
+        returns = torch.tensor(returns)
+        returns = (returns - returns.mean()) / (returns.std() + self.eps)
+        for log_prob, R in zip(probs, returns, strict=True):
+            policy_loss.append(-log_prob * R)
+        self.optimizer.zero_grad()
+        policy_loss = torch.cat(policy_loss).sum()
+        policy_loss.backward()
+        self.optimizer.step()
+        return min_reward
+
+
+def run_agent(agent, n_steps):
+    while True:
+        agent.run_episode(n_steps=n_steps)
+        agent.finish_episode()
+
+        if agent.running_reward > agent.reward_threshold:
+            print(f"Solved! Running reward is now {agent.running_reward}!")
+            break
+
+
+class ReinforcementLearningRpcTest(RpcAgentTestFixture):
+    @dist_init(setup_rpc=False)
+    def test_rl_rpc(self):
+        if self.rank == 0:
+            # Rank 0 is the agent.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            agent = Agent(self.world_size)
+            run_agent(agent, n_steps=int(TOTAL_EPISODE_STEP / (self.world_size - 1)))
+
+            # Ensure training was run. We don't really care about whether the task was learned,
+            # since the purpose of the test is to check the API calls.
+            self.assertGreater(agent.running_reward, 0.0)
+        else:
+            # Other ranks are observers that passively wait for instructions from the agent.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        rpc.shutdown()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..747155e3e1cbce8f8e8c14756fe3f98bf22a8987
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py
@@ -0,0 +1,337 @@
+# mypy: allow-untyped-defs
+
+import time
+
+import torch
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    wait_until_owners_and_forks_on_rank,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def my_sleep_func(seconds=1):
+    time.sleep(seconds)
+    return torch.mul(torch.tensor(1), torch.tensor(1))
+
+
+@torch.jit.script
+def my_script_func(tensor):
+    return torch.add(tensor, tensor)
+
+
+def add_rref_to_value(rref, value):
+    return rref.to_here() + value
+
+
+class FaultyAgentRpcTest(RpcAgentTestFixture):
+    # no faulty_messages defined so this fails all retryable messages - see
+    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
+    @dist_init(messages_to_delay={})
+    def test_check_failed_messages(self):
+        if self.rank == 0:
+            dst_worker_b = worker_name((self.rank + 1) % self.world_size)
+            dst_worker_c = worker_name((self.rank + 2) % self.world_size)
+
+            # Worker0 sends RPC to Worker1 and creates an RRef there
+            rref = rpc.remote(
+                dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
+            )
+            # Worker0 sends an RPC to Worker2 with the RRef as an arg
+            rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
+            # check if the output is as expected
+            self.assertEqual(
+                rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))
+            )
+        # explicitly delete all User RRefs
+        _delete_all_user_and_unforked_owner_rrefs()
+
+    @dist_init
+    def test_verify_backend_options(self):
+        self.assertEqual(
+            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
+        )
+        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
+        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
+        self.assertEqual(
+            self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC
+        )
+
+    @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
+    def test_custom_faulty_messages(self):
+        self.assertEqual(
+            {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"},
+            set(self.rpc_backend_options.messages_to_fail),
+        )
+
+    @dist_init(faulty_messages=[])
+    def test_no_faulty_messages(self):
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0)
+
+    @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_custom_messages_to_delay(self):
+        self.assertEqual(
+            self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}
+        )
+
+    def _test_remote_message_dropped_pickle(self, dst=None):
+        if self.rank != 0:
+            return
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # Since we fail python_remote_call messages synchronously, the future
+        # corresponding to this remote call will be marked with an error when
+        # this function returns.
+        rref = rpc.remote(dst_worker, my_sleep_func, args=(1,))
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Attempt to fork the RRef should raise an error indicating the rpc.remote timeout.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref._serialize()
+        # Test that using RRef as arg over RPC (which forks) results in the same
+        # error
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1))
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_remote_message_dropped_pickle(self):
+        self._test_remote_message_dropped_pickle()
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_remote_message_dropped_pickle_to_self(self):
+        self._test_remote_message_dropped_pickle(self.rank)
+
+    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
+        if self.rank != 0:
+            return
+
+        # Test the case where rpc.remote() message creation is completely dropped.
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # Since we fail the *_REMOTE_CALL messages chosen by the calling test
+        # synchronously, the future corresponding to this remote call will be
+        # marked with an error when this function returns.
+        rref = rpc.remote(dst_worker, func, args=args)
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+        # Note: during shutdown, logs will indicate "Could not find OwnerRRef..."
+        # on the owning nodes, this is expected because the OwnerRRef was never
+        # successfully created. Therefore, delAllUsers will work as expected.
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_builtin_remote_message_dropped_timeout(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_dropped_timeout(func, args)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_builtin_remote_message_dropped_timeout_to_self(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_dropped_timeout(func, args, dst=0)
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_udf_remote_message_dropped_timeout(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_dropped_timeout(func, args)
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_udf_remote_message_dropped_timeout_to_self(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_dropped_timeout(func, args, dst=0)
+
+    def _test_remote_message_delay_timeout(self, func, args, dst=None):
+        if self.rank != 0:
+            return
+        # Test the case where remote message is eventually processed on the owner,
+        # but the future on the creator times out before the response comes back.
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # 1 ms timeout (0.001 s)
+        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
+        # Future corresponding to the remote creation should time out.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref._get_future().wait()
+
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # to_here() should now pick up that rpc.remote() creation has failed.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+
+        # Test the case where rpc.remote() times out, but to_here() has already
+        # started blocking before.
+        # NOTE: we only test this when not sending to self, as to_here()
+        # calls localValue(), which does not send an RPC and thus does not have
+        # a timeout. This can be supported by allowing future.wait() to
+        # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
+        if dst_rank != self.rank:
+            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)
+
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                # to_here() should raise timeout error, since it does not know about the
+                # status of rpc.remote().
+                slow_rref.to_here(0.001)
+        # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete
+        # but this can be a noop since it may not exist on the owner yet. Later,
+        # the owner can process the RRef creation and wait for the delete message,
+        # thus leading to a timeout.
+        # Therefore, we wait until we get notification that pending owners have
+        # been confirmed before sending out RRefUserDeletes.
+        if dst_rank != self.rank:
+            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
+    def test_udf_remote_message_delay_timeout(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
+    def test_udf_remote_message_delay_timeout_to_self(self):
+        func = my_sleep_func
+        args = (1,)
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_builtin_delay_timeout(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_builtin_delay_timeout_to_self(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_script_delay_timeout(self):
+        func = my_script_func
+        args = (torch.tensor(1),)
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_script_delay_timeout_to_self(self):
+        func = my_script_func
+        args = (torch.tensor(1),)
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
+    def test_rref_to_here_timeout(self):
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref.to_here(0.01)
+
+        rref.to_here()
+
+    @dist_init(faulty_messages=[])
+    def test_rpc_builtin_timeout(self):
+        next_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(next_rank)
+        expected_error = self.get_timeout_error_regex()
+        # SCRIPT_CALL message types, which cover script/builtin calls such as
+        # torch.add, by default get a delay (see faulty_rpc_agent_test_fixture)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(
+                dst_worker,
+                torch.add,
+                args=(torch.tensor(1), torch.tensor(1)),
+                timeout=1,
+            )
+
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that the currently set default timeout is large enough such
+        # that RPCs with delays still complete.
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        fut.wait()
+
+        # Ensure timeout if we set a new default and don't override
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if we specify timeout of 0
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
+        )
+        fut.wait()
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_rpc_script_timeout(self):
+        next_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(next_rank)
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
+
+        fut = rpc.rpc_async(
+            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that the currently set default timeout is large enough such
+        # that RPCs with delays still complete.
+        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
+        fut.wait()
+
+        # Ensure timeout if we set a new default and don't override
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if we specify timeout of 0
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(
+            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
+        )
+        fut.wait()
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff7d556d10621e7290c07ecb433b865d7133bb2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py
@@ -0,0 +1,64 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed.rpc as rpc
+import torch.distributed.rpc._testing  # noqa: F401
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# The following message types are currently retried in the RREF protocol and
+# distributed autograd. Thus only these messages should be tested with the
+# Faulty RPC Agent.
+retryable_message_types = [
+    "RREF_FORK_REQUEST",
+    "RREF_CHILD_ACCEPT",
+    "RREF_USER_DELETE",
+    "CLEANUP_AUTOGRAD_CONTEXT_REQ",
+]
+
+# The following messages incur the corresponding delay in seconds while being
+# processed in FaultyTensorPipeAgent's enqueueSend() function.
+default_messages_to_delay = {
+    "PYTHON_CALL": 1.5,  # Python UDF
+    "SCRIPT_CALL": 1.5,  # Script/Builtin
+}
+
+
+class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.messages_to_fail = retryable_message_types
+        self.messages_to_delay = default_messages_to_delay
+
+    @property
+    def rpc_backend(self):
+        return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"]
+
+    @property
+    def rpc_backend_options(self):
+        return rpc.backend_registry.construct_rpc_backend_options(
+            self.rpc_backend,
+            init_method=self.init_method,
+            num_worker_threads=8,
+            num_fail_sends=3,
+            messages_to_fail=self.messages_to_fail,
+            messages_to_delay=self.messages_to_delay,
+        )
+
+    def setup_fault_injection(self, faulty_messages, messages_to_delay):
+        if faulty_messages is not None:
+            self.messages_to_fail = faulty_messages
+        if messages_to_delay is not None:
+            self.messages_to_delay = messages_to_delay
+
+    def get_shutdown_error_regex(self):
+        error_regexes = [
+            "Exception in thread pool task",
+            "Connection reset by peer",
+            "Connection closed by peer",
+        ]
+        return "|".join([f"({error_str})" for error_str in error_regexes])
+
+    def get_timeout_error_regex(self):
+        return "RPC ran for more than"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac2c36a2b89dbc2e49464110f8dfe21e5be65792
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cadb08e070060860461243ca9ef98fccdb828e18
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a802c7599e0fdc8d8ada79e9c72b3b7dafe35d3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d454970e3802e397af0201842439f8d9921d1bc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde1fe2355c2968e1b351b288d20c674835b0ca2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py
@@ -0,0 +1,113 @@
+# mypy: allow-untyped-defs
+
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.distributed.rpc import rpc_async
+from torch.testing import FileCheck
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+@torch.jit.script
+def local_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+@torch.jit.script
+def remote_add(t1, t2, dst: str):  # noqa: E999
+    return rpc_async(dst, local_add, (t1, t2)).wait()
+
+
+@torch.jit.script
+def fork_add(t1, t2, dst: str):
+    fut = torch.jit._fork(remote_add, t1, t2, dst)
+    return torch.jit._wait(fut)
+
+
+class JitDistAutogradTest(RpcAgentTestFixture):
+    @dist_init
+    def test_get_gradients(self):
+        @torch.jit.script
+        def dist_get_gradients(context_id: int) -> dict[Tensor, Tensor]:
+            return dist_autograd.get_gradients(context_id)
+
+        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_get_gradients(context_id)
+
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_dist_backward(self):
+        if self.rank != 0:
+            return
+
+        @torch.jit.script
+        def dist_backward_script(context_id: int, loss: torch.Tensor):
+            dist_autograd.backward(context_id, [loss])
+
+        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(3, 3, requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
+            dist_backward_script(context_id, loss)
+
+    @dist_init
+    def test_jit_fork_within_context(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            res = fork_add(t1, t2, dst_worker_name)
+            loss = res.sum()
+            dist_autograd.backward(context_id, [loss])
+
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+
+    @dist_init
+    def test_restore_context_after_swtich_to_jit_thread(self):
+        if self.rank != 0:
+            return
+
+        @torch.jit.script
+        def forward_script(
+            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
+        ) -> tuple[Tensor, Tensor]:
+            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
+            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
+            loss1 = res1.sum()
+
+            # SendRpcBackward is not attached, since DistAutogradContext is lost here.
+            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
+            res2 = res2_fut.wait()
+            loss2 = res2.sum()
+
+            return loss1, loss2
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.ones((2, 3), requires_grad=True)
+            t2 = torch.ones((2, 3), requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
+            dist_autograd.backward(context_id, [loss0, loss1])
+            grad0, grad1 = dist_autograd.get_gradients(context_id)
+            self.assertEqual(grad0, grad1)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..82a5d66e87f38672fe7076075b764a094bb81b4c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -0,0 +1,1384 @@
+# mypy: allow-untyped-defs
+
+import io
+import time
+from typing import Any
+
+import torch
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.autograd.profiler import record_function
+from torch.autograd.profiler_legacy import profile as _profile
+from torch.distributed.rpc import RRef
+from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
+from torch.futures import Future
+from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    get_function_event,
+    initialize_pg,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def rref_isinstance(rref, cls_to_check):
+    return isinstance(rref.local_value(), cls_to_check)
+
+
+def sleep(t):
+    time.sleep(t)
+
+
+def rpc_return_rref(dst):
+    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+@torch.jit.script
+def rref_local_value(rref: RRef[Tensor]) -> Tensor:
+    return rref.local_value()
+
+
+@torch.jit.script
+def list_create() -> list[int]:
+    global_list = [1, 2, 3]
+    return global_list
+
+
+@torch.jit.script
+def rref_list_mutate(rref: RRef[list[int]]) -> None:
+    rref.local_value().append(4)
+    rref.to_here().append(5)
+    rref.to_here(5.0).append(6)
+
+
+def return_value(value: int) -> int:
+    return value
+
+
+class RRefAPITest:
+    @dist_init
+    def test_rref_is_owner(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        rref_var = rpc_return_rref(dst_worker_name)
+
+        @torch.jit.script
+        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
+            return rref_var.is_owner()
+
+        res = rref_tensor_is_owner(rref_var)
+        self.assertEqual(res, False)
+
+    @dist_init
+    def test_rref_local_value(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc_return_rref(dst_worker_name)
+
+        with self.assertRaisesRegex(
+            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
+        ):
+            rref_local_value(rref)
+
+        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
+        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))
+
+    @dist_init
+    def test_local_rref_local_value(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name(self.rank)
+        rref = rpc.remote(dst_worker_name, return_value, (5,), {})
+
+        ret = rref_local_value(rref)
+        self.assertEqual(ret, 5)
+
+    def _create_rref(self):
+        owner_rank = (self.rank + 2) % self.world_size
+        return rpc.remote(
+            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
+        )
+
+    @dist_init
+    def test_user_rrefs_confirmed(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
+        )
+        self.assertEqual(ret, True)
+
+    @dist_init
+    def test_user_rrefs_confirmed_remote(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret_rref = rpc.remote(
+            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
+        )
+        self.assertEqual(ret_rref.to_here(), True)
+
+    @dist_init
+    def test_rref_list_mutate(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        list_rref = rpc.remote(dst, list_create)
+
+        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
+        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])
+
+
+@torch.jit.script
+def no_arg():
+    return 0
+
+
+@torch.jit.script
+def one_arg(value):
+    return value + 1
+
+
+@torch.jit.script
+def script_add_ones(x):
+    return torch.add(x, torch.ones(1))
+
+
+@torch.jit.script
+def script_add_ones_with_record_function(x, block: str):
+    with record_function(block):
+        return torch.add(x, torch.ones(1))
+
+
+@torch.jit.script
+def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
+    t: Tensor = torch.ones(1)
+    with record_function(block):
+        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
+        # Extra operator call to avoid de-duplication of the next async call
+        # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
+        zero = torch.zeros_like(t)
+        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
+        res = fut1.wait() + fut2.wait() + zero
+    return res
+
+
+@torch.jit.script
+def script_fork_wait_udf(tensor):
+    fut = torch.jit._fork(script_add_ones, tensor)
+    x = torch.jit._wait(fut)
+    return x
+
+
+@torch.jit.script
+def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_var.to_here()
+
+
+@torch.jit.script
+def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
+    return rref_var
+
+
+@torch.jit.script
+def script_raise_func(value):
+    if value.numel() == 2:
+        raise ValueError("Expected error")
+    return value + 1
+
+
+@torch.jit.script
+def script_fork_wait_throw(invalue):
+    fut = torch.jit._fork(script_raise_func, invalue)
+    value = torch.jit._wait(fut)
+    return value
+
+
+@torch.jit.script
+def call_rpc_with_profiling(
+    record: torch.classes.profiler._RecordFunction, dst_worker_name: str
+) -> Tensor:
+    # Call rpc_async from within ScriptFunction and ensure that we can attach
+    # profiling callbacks. Note that `record` is a
+    # torch.classes.profiler._RecordFunction object, not a Tensor handle.
+    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
+    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def call_rpc_torchscript_with_record_function(
+    dst_worker_name: str, block: str
+) -> Tensor:
+    fut = rpc.rpc_async(
+        dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block)
+    )
+    return fut.wait()
+
+
+@torch.jit.script
+def call_fork_with_profiling(record: torch.classes.profiler._RecordFunction) -> Tensor:
+    # Call fork from within ScriptFunction and ensure that we can attach profiling
+    # callbacks to the resulting future. Note that `record` is a
+    # torch.classes.profiler._RecordFunction object, not a Tensor handle.
+    fut = torch.jit._fork(one_arg, torch.tensor(1))
+    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+    ret = fut.wait()
+    return ret
+
+
+class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
+    # ScriptModule holding four RRefs obtained from rpc_return_rref(dst_worker);
+    # forward() adds the value behind each of them onto a 2x2 tensor of ones.
+    def __init__(self, dst_worker):
+        super().__init__()
+        self.rrefs = [rpc_return_rref(dst_worker) for _ in range(4)]
+
+    @torch.jit.script_method
+    def forward(self) -> Tensor:
+        # Accumulate in place, fetching every RRef's value to this worker.
+        total = torch.ones(2, 2)
+        for remote_ref in self.rrefs:
+            total += remote_ref.to_here()
+
+        return total
+
+
+@torch.jit.ignore
+def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
+    # Left as plain Python (torch.jit.ignore) so that only its RRef[Tensor]
+    # annotations are seen by TorchScript callers; returns its argument as is.
+    return rref_var
+
+
+@torch.jit.script
+def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
+    # Scripted caller of the ignored Python helper above: passes the RRef
+    # through rref_python_annotation and then fetches its tensor value.
+    return rref_python_annotation(rref_var).to_here()
+
+
+class RRefTypingTest:
+    # Mixin with tests that move RRef-typed values between Python and
+    # TorchScript. Each test targets the next rank (wrapping around).
+    @dist_init
+    def test_rref_as_arg_and_return(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        expected = one_arg(torch.ones(2, 2))
+
+        # The RRef under test is created on (and owned by) the current rank.
+        own_rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))
+
+        # rpc_sync with the RRef as an argument on another worker.
+        fetched = rpc.rpc_sync(dst, rref_to_here, args=(own_rref,))
+        self.assertEqual(fetched, expected)
+
+        # rpc_sync whose result is the RRef itself.
+        returned_rref = rpc.rpc_sync(dst, return_rref, args=(own_rref,))
+        self.assertEqual(returned_rref.to_here(), expected)
+
+        # rpc.remote with the RRef as an argument on another worker.
+        value_rref = rpc.remote(dst, rref_to_here, args=(own_rref,))
+        self.assertEqual(value_rref.to_here(), expected)
+
+        # rpc.remote whose result is the RRef: two to_here() hops are needed,
+        # one for the outer RRef and one for the RRef it contains.
+        nested_rref = rpc.remote(dst, return_rref, args=(own_rref,))
+        self.assertEqual(nested_rref.to_here().to_here(), expected)
+
+    @dist_init
+    def test_my_script_module_with_rrefs(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # Calling the module runs its scripted forward() over the held RRefs.
+        output = MyScriptModuleWithRRefs(dst)()
+        self.assertEqual(output, torch.ones(2, 2) * 9)
+
+    @dist_init
+    def test_rref_python_annotation(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        remote_value = rpc_return_rref(dst)
+
+        # The scripted function routes the RRef through an ignored Python
+        # function that is typed only by its annotations.
+        self.assertEqual(rref_script_annotation(remote_value), torch.ones(2, 2) + 1)
+
+
+class FutureTypingTest:
+    # Mixin with tests that hand Future[Tensor] values across the boundary
+    # between Python and TorchScript in both directions.
+    @dist_init
+    def test_future_passed_between_python_and_jit(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        expected = torch.tensor([10, 10])
+
+        # Python -> script: a future created in Python is awaited in script.
+        python_future = rpc.rpc_async(
+            worker_name(dst_rank), two_args_two_kwargs, args=inputs
+        )
+
+        @torch.jit.script
+        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
+            return fut.wait()
+
+        self.assertEqual(future_wait_in_script(python_future), expected)
+
+        # Script -> Python: a future created in script is awaited in Python.
+        @torch.jit.script
+        def future_return_to_python(
+            dst_rank: int, inputs: tuple[Tensor, Tensor]
+        ) -> Future[Tensor]:
+            return rpc.rpc_async(f"worker{dst_rank}", two_args_two_kwargs, inputs)
+
+        self.assertEqual(future_return_to_python(dst_rank, inputs).wait(), expected)
+
+    @dist_init
+    def test_future_python_annotation(self):
+        # Only rank 0 drives this test.
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        input_0 = torch.ones(2, 2)
+        input_1 = 1
+        expected = torch.add(input_0, input_1)
+
+        # Ignored by TorchScript; the scripted caller below relies solely on
+        # the Future[Tensor] return annotation.
+        @torch.jit.ignore
+        def python_return_future() -> Future[Tensor]:
+            return rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
+
+        @torch.jit.script
+        def script_use_future() -> Tensor:
+            pending = python_return_future()
+            return pending.wait()
+
+        self.assertEqual(script_use_future(), expected)
+
+
+@torch.jit.script
+class MyScriptClass:
+    # Minimal TorchScript class wrapping one int; the tests in this file put
+    # instances of it behind RRefs.
+    def __init__(self, a: int):
+        self.a = a
+
+    def get_value(self) -> int:
+        # Return the int supplied at construction time.
+        return self.a
+
+
+@torch.jit.interface
+class MyModuleInterface(torch.nn.Module):
+    # TorchScript module interface declaring a no-argument forward() that
+    # returns a Tensor. It is used in this file as the type hint for RRefs
+    # that hold a MyScriptModule.
+    def forward(self) -> Tensor:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+
+class MyScriptModule(torch.jit.ScriptModule):
+    # ScriptModule whose only state is `a`, a tensor of ones built from
+    # `rank`; both scripted methods simply return that tensor.
+    def __init__(self, rank):
+        super().__init__()
+        self.a = torch.ones(rank)
+
+    @torch.jit.script_method
+    def forward(self) -> Tensor:
+        # Matches the forward() signature declared by MyModuleInterface.
+        return self.a
+
+    @torch.jit.script_method
+    def custom_func(self) -> Tensor:
+        # A second scripted method besides forward(); returns the same tensor.
+        return self.a
+
+
+def owner_create_rref_my_script_class(a):
+    # Build a MyScriptClass(a) and wrap it in an RRef on the calling worker.
+    return rpc.RRef(MyScriptClass(a))
+
+
+def owner_create_rref_my_script_module(a):
+    # Build a MyScriptModule(a) and wrap it in an RRef, passing the module
+    # interface as the type hint for the ScriptModule payload.
+    return rpc.RRef(MyScriptModule(a), type_hint=MyModuleInterface)
+
+
+@torch.jit.script
+def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int:
+    # Fetch the MyScriptClass instance behind the RRef and return its int.
+    return rref.to_here().get_value()
+
+
+@torch.jit.script
+def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor:
+    # Fetch the module behind the RRef (typed via the interface) and return
+    # the result of its forward().
+    return rref.to_here().forward()
+
+
+class LocalRRefTest:
+    # Mixin with tests for RRefs that wrap TorchScript classes and modules.
+    # Every test body runs on rank 0 only; other ranks return immediately.
+    @dist_init
+    def test_create_local_script_class_rref_in_py(self):
+        if self.rank != 0:
+            return
+
+        # Wrap a script class instance in a local RRef and read it back.
+        local_rref = rpc.RRef(MyScriptClass(self.rank))
+        self.assertEqual(local_rref.to_here().get_value(), self.rank)
+
+    @dist_init
+    def test_create_local_script_module_rref_in_py(self):
+        if self.rank != 0:
+            return
+
+        # With the interface given as type hint, a ScriptModule can be put
+        # behind a local RRef and its forward() invoked.
+        local_rref = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
+        self.assertEqual(local_rref.to_here().forward(), torch.ones(self.rank))
+
+        # Without the type hint the construction must be rejected.
+        expected_message = (
+            "The RRef being created contains a ScriptModule, "
+            "must provide its ModuleInterface type hint."
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_message):
+            rpc.RRef(MyScriptModule(self.rank))
+
+    @dist_init
+    def test_return_local_script_class_rref_in_py_and_use_in_script(self):
+        if self.rank != 0:
+            return
+
+        owner_name = worker_name((self.rank + 1) % self.world_size)
+
+        # The remote worker creates the RRef in Python and sends it back.
+        remote_rref = rpc.rpc_sync(
+            owner_name, owner_create_rref_my_script_class, args=(self.rank,)
+        )
+
+        # Must stay TorchScript-compatible: it is also scripted further down.
+        def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
+            args = (rref,)
+            kwargs: dict[str, Any] = {}
+            pending = rpc.rpc_async(
+                rref.owner(), script_rref_get_value_my_script_class, args, kwargs
+            )
+            return pending.wait()
+
+        # Python-side RPC call, scripted function executed remotely.
+        self.assertEqual(use_rref_on_owner(remote_rref), self.rank)
+
+        # Script-side RPC call, scripted function executed remotely.
+        scripted_use = torch.jit.script(use_rref_on_owner)
+        self.assertEqual(scripted_use(remote_rref), self.rank)
+
+    @dist_init
+    def test_return_local_script_module_rref_in_py_and_use_in_script(self):
+        if self.rank != 0:
+            return
+
+        owner_name = worker_name((self.rank + 1) % self.world_size)
+
+        # The remote worker creates the RRef in Python and sends it back.
+        remote_rref = rpc.rpc_sync(
+            owner_name, owner_create_rref_my_script_module, args=(self.rank,)
+        )
+
+        # Must stay TorchScript-compatible: it is also scripted further down.
+        def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
+            args = (rref,)
+            kwargs: dict[str, Any] = {}
+            pending = rpc.rpc_async(
+                rref.owner_name(),
+                script_rref_run_forward_my_script_module,
+                args,
+                kwargs,
+            )
+            return pending.wait()
+
+        # Python-side RPC call, scripted function executed remotely.
+        self.assertEqual(use_rref_on_owner(remote_rref), torch.ones(self.rank))
+
+        # Script-side RPC call, scripted function executed remotely.
+        scripted_use = torch.jit.script(use_rref_on_owner)
+        self.assertEqual(scripted_use(remote_rref), torch.ones(self.rank))
+
+
+def python_function():
+    # Deliberately NOT scripted: used to check that TorchScript cannot invoke
+    # a plain Python function through rpc_async.
+    return 0
+
+
+@torch.jit.script
+def two_args_two_kwargs(
+    first_arg,
+    second_arg,
+    first_kwarg=torch.tensor([3, 3]),
+    second_kwarg=torch.tensor([4, 4]),
+):
+    # Sum of all four inputs. With the defaults left in place the two keyword
+    # arguments contribute [3, 3] + [4, 4] to the result.
+    return first_arg + second_arg + first_kwarg + second_kwarg
+
+
+@torch.jit.script
+def assorted_types_args_kwargs(
+    tensor_arg: Tensor,  # noqa: E999
+    str_arg: str,
+    int_arg: int,
+    tensor_kwarg: Tensor = torch.tensor([2, 2]),
+    str_kwarg: str = "str_kwarg",
+    int_kwarg: int = 2,
+):
+    # Combine each positional argument with the keyword argument of the same
+    # type and return the three results as a (Tensor, str, int) tuple.
+    return tensor_arg + tensor_kwarg, str_arg + str_kwarg, int_arg + int_kwarg
+
+
+@torch.jit.script
+def raise_script():
+    # Scripted function that unconditionally raises; used as an RPC target
+    # to check how remote script errors reach the caller.
+    raise RuntimeError("Expected error")
+
+
+@torch.jit.script
+def script_rpc_async_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    # rpc_async flavour: call two_args_two_kwargs on the destination worker
+    # from TorchScript and wait on the returned future.
+    pending = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return pending.wait()
+
+
+@torch.jit.script
+def script_rpc_sync_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    # rpc_sync flavour: call two_args_two_kwargs on the destination worker
+    # from TorchScript and return its result directly.
+    return rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
+
+
+@torch.jit.script
+def script_rpc_remote_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    # rpc.remote flavour: call two_args_two_kwargs on the destination worker
+    # from TorchScript, then fetch the value behind the resulting RRef.
+    result_rref = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    fetched = result_rref.to_here()
+    return fetched
+
+
+class JitRpcOpTest:
+    # Mixin with tests that call functions remotely from TorchScript. Every
+    # test body runs on rank 0 only and targets the next rank.
+    @dist_init
+    def test_all_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {}
+
+        # No kwargs given: both defaults apply, 1 + 2 + 3 + 4 = 10.
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([10, 10]))
+
+    @dist_init
+    def test_some_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {"first_kwarg": torch.tensor([2, 2])}
+
+        # Only first_kwarg overridden: 1 + 2 + 2 + 4 = 9.
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([9, 9]))
+
+    @dist_init
+    def test_no_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        # Both kwargs overridden: 1 + 2 + 2 + 3 = 8.
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([8, 8]))
+
+    @dist_init
+    def test_args_and_kwargs_contain_different_types(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_with_assorted_types(
+            dst_worker_name: str,
+        ):
+            args = (torch.tensor([1, 1]), "str_arg", 1)
+            # Must annotate the value type as `Any`, because JIT type inference
+            # does not support multiple types when defining a Dict.
+            # The error JIT gives is,
+            # "Dict values must contain only a single type, "
+            # "expected: Tensor but found str instead."
+            kwargs: dict[str, Any] = {
+                "tensor_kwarg": torch.tensor([3, 3]),
+                "str_kwarg": "_str_kwarg",
+                "int_kwarg": 3,
+            }
+            fut = rpc.rpc_async(
+                dst_worker_name, assorted_types_args_kwargs, args, kwargs
+            )
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_with_assorted_types(dst_worker_name)
+        self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))
+
+    @dist_init
+    def test_kwargs_not_passed(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # rpc_async is called with an (empty) args tuple but no kwargs at all.
+        @torch.jit.script
+        def script_rpc_async_call_without_kwargs_passed(
+            dst_worker_name: str,
+        ):
+            args = ()
+            fut = rpc.rpc_async(dst_worker_name, no_arg, args)
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name)
+        self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_args_kwargs_are_neither_passed(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # rpc_async is called with neither args nor kwargs.
+        @torch.jit.script
+        def script_rpc_async_call_without_args_kwargs_passed(
+            dst_worker_name: str,
+        ):
+            fut = rpc.rpc_async(dst_worker_name, no_arg)
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
+        self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_less_than_needed_args_are_specified(self):
+        if self.rank != 0:
+            return
+
+        # Notice, args matching happens during scripting, so the function is
+        # defined inside the assertRaisesRegex block and never called.
+        with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):
+
+            @torch.jit.script
+            def script_rpc_async_call_with_less_args(
+                dst_worker_name: str,  # noqa: E999
+            ):
+                args = (torch.tensor([1, 1]),)
+                kwargs = {}
+                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+                ret = fut.wait()
+                return ret
+
+    @dist_init
+    def test_more_than_needed_args_are_specified(self):
+        if self.rank != 0:
+            return
+
+        # Notice, args matching happens during scripting, so the function is
+        # defined inside the assertRaisesRegex block and never called.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Expected at most 4 arguments but found 5 positional arguments",
+        ):
+
+            @torch.jit.script
+            def script_rpc_async_call_with_more_args(
+                dst_worker_name: str,
+            ):
+                args = (
+                    torch.tensor([1, 1]),
+                    torch.tensor([2, 2]),
+                    torch.tensor([3, 3]),
+                    torch.tensor([4, 4]),
+                    torch.tensor([5, 5]),
+                )
+                kwargs = {}
+                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+                ret = fut.wait()
+                return ret
+
+    @dist_init
+    def test_unexepected_kwarg_is_specified(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Notice, kwargs matching happens during execution, so scripting
+        # succeeds and the error only appears when the function is called.
+        @torch.jit.script
+        def script_rpc_async_call_with_unexpected_kwarg(
+            dst_worker_name: str,  # noqa: E999
+        ):
+            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+            kwargs = {"third_kwarg": torch.tensor([1, 1])}
+            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        # The call itself must raise; nothing after it in the block would run.
+        with self.assertRaisesRegex(
+            RuntimeError, "Unknown keyword argument 'third_kwarg'"
+        ):
+            script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
+
+    @dist_init
+    def test_call_python_function_remotely_from_script_not_supported(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # python_function is plain Python, not a script function.
+        @torch.jit.script
+        def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        # The call itself must raise; nothing after it in the block would run.
+        with self.assertRaisesRegex(
+            RuntimeError, "attempted to get undefined function"
+        ):
+            rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)
+
+    @dist_init
+    def test_call_script_function_that_raises_remotely_from_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # The error raised by the remote script function is expected to reach
+        # the caller as a RuntimeError that still carries the original
+        # message text ("Expected error").
+        @torch.jit.script
+        def rpc_async_call_remote_raising_torchscript_in_torchscript(
+            dst_worker_name: str,
+        ):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        # The call itself must raise; nothing after it in the block would run.
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
+            rpc_async_call_remote_raising_torchscript_in_torchscript(dst_worker_name)
+
+    @dist_init
+    def test_call_script_function_that_not_exists_remotely_from_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Scripted locally inside this test on rank 0 only, so the destination
+        # worker is expected not to know a function of this name.
+        @torch.jit.script
+        def nonexisting_script():
+            return 0
+
+        @torch.jit.script
+        def rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
+            dst_worker_name: str,
+        ):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        # The call itself must raise; nothing after it in the block would run.
+        with self.assertRaisesRegex(
+            RuntimeError, "attempted to get undefined function nonexisting_script"
+        ):
+            rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
+                dst_worker_name
+            )
+
+
+@torch.jit.ignore
+def my_script_module_init(rank: int) -> MyModuleInterface:
+    # Plain-Python constructor wrapper (ignored by TorchScript) whose return
+    # annotation presents the new MyScriptModule as a MyModuleInterface.
+    return MyScriptModule(rank)
+
+
+@torch.jit.script
+def construct_my_script_module(rank: int) -> MyModuleInterface:
+    # Scripted entry point that builds the module by delegating to the
+    # ignored Python helper my_script_module_init.
+    return my_script_module_init(rank)
+
+
+@torch.jit.script
+def run_ref_script_module(
+    ref_script_module: RRef[MyModuleInterface], t: Tensor
+) -> Tensor:
+    # Fetch the module behind the RRef, run its forward() and add `t` to the
+    # output.
+    forward_out = ref_script_module.to_here().forward()
+    return forward_out + t
+
+
+@torch.jit.script
+def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool:
+    # Scripted wrapper exposing RRef.confirmed_by_owner() to callers.
+    return rref.confirmed_by_owner()
+
+
+@torch.jit.script
+def save_rref(rref_var: RRef[Tensor], fname: str) -> None:
+    # Attempt to torch.save an RRef from TorchScript. A test in this file
+    # expects this to fail with "RRef jit pickling is only allowed inside
+    # RPC calls" when invoked outside an RPC.
+    torch.save(rref_var, fname)
+
+
+@torch.jit.script
+def script_add(x: Tensor, y: Tensor) -> Tensor:
+    # Scripted tensor addition; the RPC target used by async_add.
+    return x + y
+
+
+@rpc.functions.async_execution
+@torch.jit.script
+def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
+    # Scripted async-execution function: forwards x and y to script_add on
+    # worker `to` and returns the resulting Future without waiting on it.
+    return rpc.rpc_async(to, script_add, (x, y))
+
+
+@rpc.functions.async_execution
+@torch.jit.script
+def async_wrong_type() -> Tensor:
+    # Decorated as async_execution yet returns a plain Tensor instead of a
+    # Future. NOTE(review): the name suggests the mismatch is intentional for
+    # a negative test; the caller is not visible in this part of the file.
+    return torch.zeros(2)
+
+
+def load_script_module_with_pickled_rref(pickled_script_module):
+    # Deserialize a ScriptModule from its saved bytes, call it with no
+    # arguments and return whatever that call produces.
+    loaded_module = torch.jit.load(io.BytesIO(pickled_script_module))
+    return loaded_module()
+
+
+class JitRpcTest(
+    RRefAPITest,
+    RRefTypingTest,
+    LocalRRefTest,
+    JitRpcOpTest,
+    FutureTypingTest,
+    RpcAgentTestFixture,
+):
+    @dist_init
+    def test_torchscript_function(self):
+        # one_arg must give the same result when run locally, via rpc_sync,
+        # via a remote RRef, and via an RRef owned by this very worker.
+        peer = worker_name((self.rank + 1) % self.world_size)
+        expected = one_arg(torch.ones(2, 2))
+
+        sync_result = rpc.rpc_sync(peer, one_arg, args=(torch.ones(2, 2),))
+        self.assertEqual(sync_result, expected)
+
+        remote_rref = rpc.remote(peer, one_arg, args=(torch.ones(2, 2),))
+        self.assertEqual(remote_rref.to_here(), expected)
+
+        # RRef whose owner is the calling worker itself.
+        self_rref = rpc.remote(
+            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
+        )
+        self.assertEqual(self_rref.to_here(), expected)
+
+    @dist_init
+    def test_torchscript_function_exception(self):
+        # Passing two positional arguments to one_arg must be rejected by
+        # rpc_sync first and by rpc.remote second.
+        peer = worker_name((self.rank + 1) % self.world_size)
+        for issue_call in (rpc.rpc_sync, rpc.remote):
+            with self.assertRaisesRegex(
+                RuntimeError, r"one_arg\(\) expected at most"
+            ):
+                issue_call(peer, one_arg, args=(10, 20))
+
+    @dist_init
+    def test_torchscript_functions_not_supported(self):
+        # Checks which TorchScript objects can be used as RPC targets: a
+        # script class is accepted, a bound script module method is not.
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        my_local_script_module = MyScriptModule(self.rank)
+
+        # Instantiating MyScriptModule is not thread safe. Wait at a barrier
+        # until every rank has finished its local instantiation; otherwise a
+        # server thread handling the RPC below could instantiate
+        # MyScriptModule in parallel with the line above.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        # rpc_sync still accepts a script class and runs it through the same
+        # code path as a Python call.
+        rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))
+
+        # A bound script module method is rejected as an RPC target (checked
+        # here via rpc_async). The exact TypeError text differs between
+        # Python versions; "pickle" is the common word that can be matched.
+        with self.assertRaisesRegex(TypeError, "pickle"):
+            rpc.rpc_async(dst_worker_name, my_local_script_module.forward, args=())
+
+    @dist_init
+    def test_remote_script_module(self):
+        # TODO, need more investigation
+        # An RRef leak is reported at shutdown; the suspicion is that the RRef
+        # passed as an argument crosses the pybind boundary and is not garbage
+        # collected by Python by the time shutdown() is called. The leak check
+        # is therefore switched off for this test.
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = True
+
+        expected = torch.ones(self.rank) + torch.ones(self.rank)
+        owner_name = worker_name((self.rank + 1) % self.world_size)
+        module_rref = rpc.remote(
+            owner_name, construct_my_script_module, args=(self.rank,)
+        )
+
+        # RRef passed as an argument to its owner.
+        owner_result = rpc.rpc_sync(
+            owner_name,
+            run_ref_script_module,
+            args=(module_rref, torch.ones(self.rank)),
+        )
+        self.assertEqual(owner_result, expected)
+
+        # RRef passed as an argument to this worker (self/user): must fail.
+        expected_message = (
+            "is an RRef to a ScriptModule. It can't be sent through RPC from owner,"
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_message):
+            rpc.rpc_sync(
+                worker_name(self.rank),
+                run_ref_script_module,
+                args=(module_rref, torch.ones(self.rank)),
+            )
+
+    @dist_init
+    def test_create_script_module_on_remote(self):
+        # Construct MyScriptModule on another worker through rpc_sync and
+        # rpc.remote, and use the result both remotely and locally.
+        dst_name = worker_name((self.rank + 1) % self.world_size)
+        # Construct on remote end with rpc_sync
+        created_script_module = rpc.rpc_sync(
+            dst_name, MyScriptModule, args=(self.rank,)
+        )
+        # Forward should output a ones tensor of self.rank.
+        # assertIsInstance reports the actual type if the check fails.
+        self.assertIsInstance(created_script_module, torch.jit.ScriptModule)
+        rank_ones_tensor = created_script_module()
+        self.assertEqual(torch.ones(self.rank), rank_ones_tensor)
+
+        # Construct ScriptModule with rpc.remote.
+        remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,))
+        # Verify it is an instance of ScriptModule on remote end.
+        remote_end_is_script = rpc.rpc_sync(
+            remote_script_module.owner(),
+            rref_isinstance,
+            args=(remote_script_module, torch.jit.ScriptModule),
+        )
+        self.assertTrue(remote_end_is_script)
+        # Run forward pass remotely.
+        remote_forward_output = remote_script_module.rpc_sync().forward()
+        self.assertEqual(remote_forward_output, torch.ones(self.rank))
+        # Run function defined on ScriptModule remotely.
+        remote_func_output = remote_script_module.rpc_sync().custom_func()
+        self.assertEqual(remote_func_output, torch.ones(self.rank))
+        # Ensure we can transfer ScriptModule RRef to this rank and run
+        # forward pass.
+        local_script_module = remote_script_module.to_here()
+        self.assertIsInstance(local_script_module, torch.jit.ScriptModule)
+        rank_ones_tensor = local_script_module()
+        self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
+        local_script_func_output = local_script_module.custom_func()
+        self.assertEqual(local_script_func_output, torch.ones(self.rank))
+
+    @dist_init
+    def test_load_script_module_with_pickled_rref(self):
+        # A ScriptModule holding RRefs is saved to bytes, shipped to another
+        # worker, loaded and run there; the result must match running an
+        # identically constructed module locally.
+        dst_name = worker_name((self.rank + 1) % self.world_size)
+        m1 = MyScriptModuleWithRRefs(dst_name)
+        m2 = MyScriptModuleWithRRefs(dst_name)
+
+        f = io.BytesIO()
+
+        # JIT RRef pickling is enabled only around the save. The try/finally
+        # makes sure it is switched off again even if saving raises.
+        rpc._enable_jit_rref_pickle()
+        try:
+            torch.jit.save(m1, f)
+        finally:
+            rpc._disable_jit_rref_pickle()
+
+        out1 = rpc.rpc_sync(
+            dst_name, load_script_module_with_pickled_rref, args=(f.getvalue(),)
+        )
+        out2 = m2()
+        self.assertEqual(out1, out2)
+
+    @dist_init
+    def test_rref_jit_pickle_not_supported(self):
+        # Saving an RRef from TorchScript outside of an RPC call must fail.
+        dst = worker_name((self.rank + 1) % self.world_size)
+        remote_value = rpc_return_rref(dst)
+        expected_message = "RRef jit pickling is only allowed inside RPC calls"
+        with (
+            TemporaryFileName() as fname,
+            self.assertRaisesRegex(RuntimeError, expected_message),
+        ):
+            save_rref(remote_value, fname)
+
+    @dist_init
+    def test_remote_script_throw(self):
+        # A two-element input makes script_raise_func raise on the remote
+        # side; the error must surface when the RRef is fetched.
+        peer = worker_name((self.rank + 1) % self.world_size)
+        failing_rref = rpc.remote(peer, script_raise_func, args=(torch.ones(2),))
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            failing_rref.to_here()
+
+    @dist_init
+    def test_remote_script_udf(self):
+        # rpc.remote on a scripted function that forks and waits internally.
+        peer = worker_name((self.rank + 1) % self.world_size)
+        result_rref = rpc.remote(peer, script_fork_wait_udf, args=(torch.ones(2),))
+        self.assertEqual(result_rref.to_here(), torch.ones(2) * 2)
+
+    @dist_init
+    def test_async_script_udf(self):
+        # rpc_async on a scripted function that forks and waits internally.
+        peer = worker_name((self.rank + 1) % self.world_size)
+        pending = rpc.rpc_async(peer, script_fork_wait_udf, args=(torch.ones(2),))
+        self.assertEqual(pending.wait(), torch.ones(2) * 2)
+
+    @dist_init
+    def test_callback_simple(self):
+        # A single then() callback adds one to the RPC result.
+        def add_one(fut):
+            return fut.wait() + 1
+
+        peer = worker_name((self.rank + 1) % self.world_size)
+        rpc_future = rpc.rpc_async(peer, script_fork_wait_udf, args=(torch.ones(2),))
+        chained = rpc_future.then(add_one)
+        self.assertEqual(chained.wait(), torch.ones(2) * 2 + 1)
+
+    @dist_init
+    def test_callback_chain(self):
+        # Chain 20 then() callbacks, each adding one to the previous result.
+        size = self.rank + 1
+        num_cbs = 20
+
+        def add_one(fut):
+            return fut.wait() + 1
+
+        chained = rpc.rpc_async(
+            worker_name(size % self.world_size), one_arg, args=(torch.ones(size, size),)
+        )
+        for _ in range(num_cbs):
+            chained = chained.then(add_one)
+
+        self.assertEqual(chained.wait(), torch.ones(size, size) + 1 + num_cbs)
+
+    @dist_init
+    def test_add_done_callback(self):
+        # Checks that a callback registered with add_done_callback runs and
+        # sees the completed future's value.
+        callback_called = None
+
+        def callback(fut):
+            # Record twice the future's value so the test can verify both that
+            # the callback ran and what value it observed.
+            nonlocal callback_called
+            callback_called = fut.wait() * 2
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+
+        # Registration order matters: the done-callback is added first, the
+        # `then` continuation second.
+        future.add_done_callback(callback)
+        future_then = future.then(lambda _: True)
+
+        self.assertEqual(future.wait(), torch.ones(2) * 2)
+
+        # We have no guarantee that the add_done_callback fn will execute before the test finishes.
+        # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback
+        future_then.wait()
+        self.assertEqual(callback_called, torch.ones(2) * 4)
+
+    @dist_init
+    def test_async_script_throw(self):
+        # The error raised inside the remote forked task must surface from
+        # wait() on the rpc_async future.
+        peer = worker_name((self.rank + 1) % self.world_size)
+        pending = rpc.rpc_async(peer, script_fork_wait_throw, args=(torch.ones(2),))
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            pending.wait()
+
+    @dist_init
+    def test_callback_with_exception(self):
+        # The callback first observes the remote error on the original
+        # future, then raises its own error, which must be the one reported
+        # by the future that then() returns.
+        def failing_callback(fut):
+            with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+                fut.wait()
+            raise RuntimeError("Another expected error")
+
+        peer = worker_name((self.rank + 1) % self.world_size)
+        rpc_future = rpc.rpc_async(
+            peer, script_fork_wait_throw, args=(torch.ones(2),)
+        )
+        chained = rpc_future.then(failing_callback)
+
+        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
+            chained.wait()
+
+    @dist_init
+    def test_call_rpc_with_profiling(self):
+        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
+        # future from within a script function that calls rpc_async.
+        # Only rank 0 runs the check; the worker names below are hard-coded.
+        if self.rank != 0:
+            return
+
+        with _profile() as prof:
+            prof_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                torch._jit_internal._qualified_name(one_arg),
+                "worker0",
+                "worker1",
+            )
+            with torch.autograd.profiler.record_function(prof_key) as rf:
+                call_rpc_with_profiling(rf.record, "worker1")
+        # TODO: Can't get a reliable time for this profiling event since
+        # it's hard to estimate the execution time on the remote end for non-UDFs.
+        # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
+        # After that, this test should be modified to validate the function time.
+        events = prof.function_events
+        function_event = get_function_event(events, prof_key)
+        # assertIn shows both operands if the check fails.
+        self.assertIn(
+            torch._jit_internal._qualified_name(one_arg), function_event.name
+        )
+
+    @dist_init
+    def test_rpc_async_jit_profiled(self):
+        # Tests that rpc_async calls made from within a TorchScript function
+        # are profiled. Only rank 0 runs the check.
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = worker_name(dst_rank)
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {}
+        with _profile() as prof:
+            script_rpc_async_call(dst_worker_name, args, kwargs)
+
+        function_events = prof.function_events
+        qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
+
+        # Exactly one event recorded on this rank must belong to the
+        # rpc_async call, and it must carry the ASYNC_JIT profiling key.
+        local_matches = [
+            evt
+            for evt in function_events
+            if qual_name in evt.name and evt.node_id == self.rank
+        ]
+        self.assertEqual(len(local_matches), 1)
+        profiled_name = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC_JIT,
+            qual_name,
+            worker_name(self.rank),
+            dst_worker_name,
+        )
+        self.assertEqual(profiled_name, local_matches[0].name)
+
+        # All remote events should have taken place on dst_rank.
+        remote_events = [evt for evt in function_events if evt.is_remote]
+        remote_node_ids = {evt.node_id for evt in remote_events}
+        self.assertEqual(remote_node_ids, {dst_rank})
+
+        # script_rpc_async_call invokes the add operator, so an aten::add
+        # entry is expected among the remote events.
+        remote_add = next(evt for evt in remote_events if "aten::add" in evt.name)
+        self.assertEqual(remote_add.name, f"{profiled_name}#remote_op: aten::add")
+
+    @dist_init
+    def test_record_function_on_caller_rpc_async(self):
+        """Check profiling of two JIT ``rpc_async`` calls wrapped in ``record_function``.
+
+        Runs on rank 0 only. Expects one event for the enclosing scope and two
+        ASYNC_JIT events, each shorter than the enclosing scope.
+        """
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            block_scope = "foo"
+            with _profile() as prof:
+                # Runs 2 rpc_async calls within JIT under record_function.
+                record_function_on_caller_rpc_async(dst_worker_name, block_scope)
+
+            # Ensure record_function event is profiled.
+            function_events = prof.function_events
+            record_function_scope_event = [
+                event for event in function_events if event.name == block_scope
+            ]
+            self.assertEqual(1, len(record_function_scope_event))
+            record_function_scope_event = record_function_scope_event[0]
+            # Ensure RPC future is profiled.
+            expected_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC_JIT,
+                torch._jit_internal._qualified_name(script_add_ones),
+                worker_name(self.rank),
+                dst_worker_name,
+            )
+            jit_rpc_events = [
+                event for event in function_events if event.name == expected_key
+            ]
+            self.assertEqual(2, len(jit_rpc_events))
+            # Validate that the record_function scope time is greater than both
+            # of the individual RPC async call times. The reason it is not necessarily
+            # greater than the sum is because the two can execute in parallel.
+            for jit_rpc_event in jit_rpc_events:
+                # assertGreater prints both timings when the check fails.
+                self.assertGreater(
+                    record_function_scope_event.cpu_time_total,
+                    jit_rpc_event.cpu_time_total,
+                )
+
+    @dist_init
+    def test_rpc_torchscript_record_function(self):
+        """Check that a TorchScript function using ``record_function`` is profiled over RPC.
+
+        Runs on rank 0 only. Looks up the remote event named after the RPC
+        profiling key plus the record_function scope and inspects its CPU
+        children.
+        """
+        REMOTE_OP_STR = "#remote_op: "
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            block_scope = "foo"
+            with _profile() as prof:
+                call_rpc_torchscript_with_record_function(dst_worker_name, block_scope)
+
+            # Need to call below to populate CPU children.
+            prof.key_averages()
+            function_events = prof.function_events
+            expected_key = (
+                _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC_JIT,
+                    torch._jit_internal._qualified_name(
+                        script_add_ones_with_record_function
+                    ),
+                    worker_name(self.rank),
+                    dst_worker_name,
+                )
+                + REMOTE_OP_STR
+                + block_scope
+            )
+            remote_record_function_event = next(
+                evt for evt in function_events if evt.name == expected_key
+            )
+            self.assertIn(block_scope, remote_record_function_event.name)
+            remote_children = remote_record_function_event.cpu_children
+            # any() is required here: a bare generator expression is always
+            # truthy, so passing it straight to assertTrue could never fail.
+            self.assertTrue(
+                any("aten::add" in child.name for child in remote_children)
+            )
+
+    def test_record_function_jit_end_callbacks_with_fork(self):
+        """Ensure ``rf._call_end_callbacks_on_future`` works on a ``torch.jit.fork`` future in eager mode."""
+        sleep_interval = 1
+        with _profile() as prof:
+            with torch.autograd.profiler.record_function("foo") as rf:
+                forked = torch.jit._fork(sleep, sleep_interval)
+                rf._call_end_callbacks_on_future(forked)
+            forked.wait()
+
+        foo_event = get_function_event(prof.function_events, "foo")
+        self.assertEqual(foo_event.name, "foo")
+        # The callbacks fired at the right time if the recorded cpu time,
+        # scaled by 1e-6, is at least the sleep interval.
+        recorded = foo_event.cpu_time * 1e-6
+        self.assertGreaterAlmostEqual(recorded, sleep_interval)
+
+    def test_call_fork_in_jit_with_profiling(self):
+        """Ensure torch.ops.profiler._call_end_callbacks_on_jit_fut works on a
+        jit future from within a script function that uses torch.jit.fork."""
+        with _profile() as prof:
+            with torch.autograd.profiler.record_function("foo") as rf:
+                call_fork_with_profiling(rf.record)
+
+        foo_event = get_function_event(prof.function_events, "foo")
+        self.assertEqual(foo_event.name, "foo")
+
+    @dist_init
+    def test_async_function_simple(self):
+        """``rpc_sync`` on ``async_add`` returns the resolved sum of the two tensors."""
+        next_worker = worker_name((self.rank + 1) % self.world_size)
+        next_next_worker = worker_name((self.rank + 2) % self.world_size)
+
+        call_args = (next_next_worker, torch.ones(2, 2), torch.ones(2, 2))
+        result = rpc.rpc_sync(next_worker, async_add, args=call_args)
+        self.assertEqual(result, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_async_function_wrong_return_type(self):
+        """``rpc_sync`` must fail when an async-execution function returns a Tensor."""
+        expected_msg = (
+            "Async functions must return an IValue of Future type, but got Tensor"
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_msg):
+            dst = worker_name((self.rank + 1) % self.world_size)
+            rpc.rpc_sync(dst, async_wrong_type)
+
+    @dist_init
+    def test_async_function_wrong_decorator_order(self):
+        """Scripting a function that is already wrapped by ``async_execution`` must raise RuntimeError."""
+        # @torch.jit.script complains about undefined value rpc. Error is shown
+        # below. The reason for not checking error string is to avoid making
+        # JIT error handling code depend on RPC tests, as we don't have any
+        # restrictions on the error message here.
+        #
+        # RuntimeError:
+        # undefined value rpc:
+        # def async_wrong_decorator_order(to, x, y):
+        #    # type: (str, Tensor, Tensor) -> Future[Tensor]
+        #    return rpc.rpc_async(to, script_add, (x, y))
+        #           ~~~ <--- HERE
+        with self.assertRaises(RuntimeError):
+
+            # The definition itself is the statement under test: applying the
+            # decorators in this order is what is expected to raise.
+            @torch.jit.script
+            @rpc.functions.async_execution
+            def async_wrong_decorator_order(
+                to: str, x: Tensor, y: Tensor
+            ) -> Future[Tensor]:
+                return rpc.rpc_async(to, script_add, (x, y))
+
+    @dist_init
+    def test_async_function_remote(self):
+        """``rpc.remote`` on ``async_add`` yields an RRef whose value is the resolved sum."""
+        next_worker = worker_name((self.rank + 1) % self.world_size)
+        next_next_worker = worker_name((self.rank + 2) % self.world_size)
+
+        call_args = (next_next_worker, torch.ones(2, 2), torch.ones(2, 2))
+        remote_sum = rpc.remote(next_worker, async_add, args=call_args)
+        self.assertEqual(remote_sum.to_here(), torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_async_function_remote_multi(self):
+        """Issue 20 ``rpc.remote`` calls to ``async_add`` and check every result."""
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        num = 20
+        # Launch every request first, then fetch the results in issue order.
+        rrefs = []
+        for scale in range(num):
+            call_args = (dst2, torch.ones(2, 2), torch.ones(2, 2) * scale)
+            rrefs.append(rpc.remote(dst1, async_add, args=call_args))
+
+        for scale, rref in enumerate(rrefs):
+            self.assertEqual(rref.to_here(), torch.ones(2, 2) + scale)
+
+    @dist_init
+    def test_async_function_wrong_return_type_remote(self):
+        """``to_here()`` must fail when the remote async-execution function returned a Tensor."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, async_wrong_type)
+
+        expected_msg = (
+            "Async functions must return an IValue of Future type, but got Tensor"
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_msg):
+            rref.to_here()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bedaad32d0e904a9a7523f31eced9cef96e832d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py
@@ -0,0 +1,219 @@
+# mypy: allow-untyped-defs
+
+
+import torch
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.distributed.rpc import RRef
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+@torch.jit.script
+def two_args_two_kwargs(
+    first_arg,
+    second_arg,
+    first_kwarg=torch.tensor([3, 3]),
+    second_kwarg=torch.tensor([4, 4]),
+):
+    """Return ``first_arg + second_arg + first_kwarg + second_kwarg``."""
+    return first_arg + second_arg + first_kwarg + second_kwarg
+
+
+@torch.jit.script
+def script_rpc_async_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    """Call ``two_args_two_kwargs`` on ``dst_worker_name`` via ``rpc_async`` and wait for the result."""
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def rpc_async_call_with_timeout(
+    dst_worker_name: str,
+    args: tuple[Tensor, Tensor],
+    kwargs: dict[str, Tensor],
+    timeout: float,
+):
+    """Like ``script_rpc_async_call`` but forwards ``timeout`` to ``rpc_async`` before waiting."""
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def rpc_async_call_with_timeout_future_ret(
+    dst_worker_name: str,
+    args: tuple[Tensor, Tensor],
+    kwargs: dict[str, Tensor],
+    timeout: float,
+):
+    """Issue the ``rpc_async`` call with ``timeout`` and return the future without waiting on it."""
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
+    return fut
+
+
+@torch.jit.script
+def rpc_async_call_future_ret(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    """Issue the ``rpc_async`` call without an explicit timeout and return the future unwaited."""
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return fut
+
+
+@torch.jit.script
+def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
+    """Return ``rref_var.to_here()`` from inside TorchScript."""
+    return rref_var.to_here()
+
+
+@torch.jit.script
+def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
+    """Return ``rref_var.to_here(timeout)`` from inside TorchScript."""
+    return rref_var.to_here(timeout)
+
+
+@torch.jit.script
+def rpc_async_with_rref_arg(dst_worker_name: str, args: tuple[RRef[Tensor]]) -> Tensor:
+    """Run ``rref_to_here`` on ``dst_worker_name`` with an RRef argument and wait for the result."""
+    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
+    ret = fut.wait()
+    return ret
+
+
+class JitFaultyAgentRpcTest(RpcAgentTestFixture):
+    """
+    Run tests for rpc_async in JIT under the faulty agent test fixture to test
+    arbitrary timeouts.
+
+    Every test body runs on rank 0 only; the other ranks return immediately.
+    """
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_timeout_in_torchscript_function(self):
+        """Call rpc_async + fut.wait() inside TorchScript and ensure a timeout is raised."""
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        expected_error = self.get_timeout_error_regex()
+        # Ensure that we get a timeout if we override the default timeout and
+        # the RPC takes longer to execute.
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)
+
+        # Ensure that we timeout if we don't specify a timeout but the default
+        # is less than the RPC takes to execute.
+        rpc._set_rpc_timeout(0.001)
+        try:
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                script_rpc_async_call(dst_worker_name, args, kwargs)
+
+            # Ensure that we run to completion if zero timeout is specified.
+            ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
+            self.assertEqual(ret, torch.tensor([8, 8]))
+        finally:
+            # reset for clean shutdown, even when an assertion above failed
+            rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_timeout_in_python(self):
+        """Ensure timeouts are raised when rpc_async is called from TorchScript
+        but the future is waited on in Python."""
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        expected_error = self.get_timeout_error_regex()
+
+        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure timeout if we don't specify but the default is less than the
+        # RPC takes to execute.
+        rpc._set_rpc_timeout(0.001)
+        try:
+            fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                fut.wait()
+
+            # Ensure run to completion if zero timeout is specified
+            fut = rpc_async_call_with_timeout_future_ret(
+                dst_worker_name, args, kwargs, 0
+            )
+            result = fut.wait()
+            self.assertEqual(result, torch.tensor([8, 8]))
+        finally:
+            # reset for clean shutdown, even when an assertion above failed
+            rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_remote_timeout_to_here_in_jit(self):
+        """Calling to_here() in JIT must raise if the rpc.remote call failed."""
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call to_here() within a ScriptFunction and ensure it raises
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref_to_here(rref)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
+    def test_rref_to_here_timeout_in_jit(self):
+        """to_here() in JIT must time out with a 0.01 timeout and succeed with a 100 timeout."""
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref_to_here_with_timeout(rref, 0.01)
+
+        # A generous timeout is expected to complete without raising.
+        rref_to_here_with_timeout(rref, 100)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_rref_timeout_pickle_in_jit(self):
+        """Passing a failed RRef as an RPC argument from JIT must raise."""
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call RPC with RRef arg in JIT, which will go through JIT pickling and
+        # ensure error is raised.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc_async_with_rref_arg(dst_worker, (rref,))
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_rref_timeout_pickle_script_func(self):
+        """Similar to the test above, but calls Python rpc with a script function."""
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call RPC with script function that takes RRef, ensure timeout during pickling
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a684b73d2f315a00465371fad3050a795251ddb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
@@ -0,0 +1,63 @@
+# mypy: allow-untyped-defs
+
+import os
+from abc import ABC, abstractmethod
+
+import torch.testing._internal.dist_utils
+
+
+class RpcAgentTestFixture(ABC):
+    """Abstract description of the RPC agent configuration a test suite runs with.
+
+    Subclasses provide the backend, its options and the error patterns.
+    ``file_init_method`` reads ``self.file_name``, which this class does not
+    define.
+    """
+
+    @property
+    def world_size(self) -> int:
+        """Number of workers in a run; fixed at 4 here."""
+        return 4
+
+    @property
+    def init_method(self):
+        """Return the rendezvous URL.
+
+        A ``tcp://`` URL built from ``MASTER_ADDR`` and ``MASTER_PORT`` is
+        returned when the ``RPC_INIT_WITH_TCP`` environment variable equals
+        ``"1"``; otherwise ``file_init_method`` is returned.
+        """
+        if os.environ.get("RPC_INIT_WITH_TCP") != "1":
+            return self.file_init_method
+        master_addr = os.environ["MASTER_ADDR"]
+        master_port = os.environ["MASTER_PORT"]
+        return f"tcp://{master_addr}:{master_port}"
+
+    @property
+    def file_init_method(self):
+        """Return ``INIT_METHOD_TEMPLATE`` from dist_utils formatted with ``self.file_name``."""
+        template = torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE
+        return template.format(file_name=self.file_name)
+
+    @property
+    @abstractmethod
+    def rpc_backend(self):
+        """Abstract property: the RPC backend; the base implementation returns None."""
+
+    @property
+    @abstractmethod
+    def rpc_backend_options(self):
+        """Abstract property: the RPC backend options; the base implementation returns None."""
+
+    def setup_fault_injection(self, faulty_messages, messages_to_delay):  # noqa: B027
+        """Method used by dist_init to prepare the faulty agent.
+
+        Does nothing for other agents.
+        """
+
+    # Shutdown sequence is not well defined, so we may see any of the following
+    # errors when running tests that simulate errors via a shutdown on the
+    # remote end.
+    @abstractmethod
+    def get_shutdown_error_regex(self):
+        """
+        Return the various error messages we may see from RPC agents while
+        running tests that check for failures. This function is used to match
+        against possible errors to ensure failures were raised properly.
+        """
+
+    @abstractmethod
+    def get_timeout_error_regex(self):
+        """
+        Return a partial string indicating the error we should receive when an
+        RPC has timed out. Useful with assertRaisesRegex() to ensure we have
+        the right errors during timeout.
+        """
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50aadc058cbdd2d5e08b4df711572828b2f2ee9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -0,0 +1,6312 @@
+# mypy: allow-untyped-defs
+
+import concurrent.futures
+import contextlib
+import json
+import operator
+import os
+import sys
+import threading
+import time
+from collections import namedtuple
+from functools import partial
+from threading import Event, Lock
+from unittest import mock
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+from torch.autograd.profiler_legacy import profile as _profile
+from torch.distributed.rpc import (
+    _get_debug_info,
+    _rref_context_get_debug_info,
+    RRef,
+    WorkerInfo,
+)
+from torch.distributed.rpc.api import _thread_local_var, _use_rpc_pickler, _wait_all
+from torch.distributed.rpc.internal import (
+    _build_rpc_profiling_key,
+    _internal_rpc_pickler,
+    PythonUDF,
+    RPCExecMode,
+)
+from torch.futures import Future
+from torch.testing._internal.common_distributed import (
+    captured_output,
+    skip_if_lt_x_gpu,
+    tp_transports,
+)
+from torch.testing._internal.common_utils import (
+    get_cycles_per_ms,
+    IS_MACOS,
+    load_tests,
+    skip_but_pass_in_sandcastle_if,
+    TemporaryFileName,
+)
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    get_function_event,
+    initialize_pg,
+    wait_until_node_failure,
+    wait_until_owners_and_forks_on_rank,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def foo_add():
+    """Return the result of ``torch.add`` on two freshly created one-element tensors of ones."""
+    lhs = torch.ones(1)
+    rhs = torch.ones(1)
+    return torch.add(lhs, rhs)
+
+
+def udf_with_torch_ops(device=-1, use_record_function=False):
+    """Run a fixed chain of torch ops (ones, ones, add, mul, relu, sigmoid).
+
+    Args:
+        device: CUDA device passed to ``torch.cuda.device``; ``-1`` means no
+            device context is entered.
+        use_record_function: when truthy, the ops run inside a
+            ``record_function("##forward##")`` scope.
+
+    Nothing is returned; the op results are discarded.
+    """
+    if device == -1:
+        device_ctx = contextlib.nullcontext()
+    else:
+        device_ctx = torch.cuda.device(device)
+
+    if use_record_function:
+        record_function_ctx = torch.autograd.profiler.record_function("##forward##")
+    else:
+        record_function_ctx = contextlib.nullcontext()
+
+    with device_ctx, record_function_ctx:
+        first = torch.ones(1)
+        second = torch.ones(1)
+        out = torch.add(first, second)
+        out = torch.mul(out, out)
+        out = out.relu()
+        out = out.sigmoid()
+
+
+# Events (operator invocations) that are expected to be run as part of the above
+# function.
+EXPECTED_REMOTE_EVENTS = [
+    "aten::ones",
+    "aten::ones",
+    "aten::add",
+    "aten::mul",
+    "aten::relu",
+    "aten::clamp_min",
+    "aten::sigmoid",
+]
+
+# Remote operations are prefixed with the following string for RPC profiling.
+REMOTE_OP_STR = "#remote_op: "
+
+
+# Module-level futures: VALUE_FUTURE is completed by set_value() and
+# set_and_check_done() and read by wait_for_value_future(); DONE_FUTURE is
+# awaited by set_and_check_done().
+VALUE_FUTURE = concurrent.futures.Future()
+DONE_FUTURE = concurrent.futures.Future()
+
+# NOTE(review): presumably a cycle count used by tests elsewhere in this
+# file; it is not referenced in this section -- confirm.
+FIFTY_MIL_CYCLES = 50000000
+
+# Counter incremented by _increment_count() and zeroed by _reset_count().
+_rpc_barrier_count = 0
+
+
+def _increment_count():
+    """Add one to the module-level ``_rpc_barrier_count``."""
+    global _rpc_barrier_count
+    _rpc_barrier_count += 1
+
+
+def _reset_count():
+    """Set the module-level ``_rpc_barrier_count`` back to zero."""
+    global _rpc_barrier_count
+    _rpc_barrier_count = 0
+
+
+class StubRpcAgent:
+    """Minimal stand-in agent that only knows its world size."""
+
+    def __init__(self, world_size):
+        self.world_size = world_size
+
+    def get_worker_infos(self):
+        """Return a set with one ``WorkerInfo`` per rank in ``range(world_size)``."""
+        infos = set()
+        for rank in range(self.world_size):
+            infos.add(WorkerInfo(name=worker_name(rank), id=rank))
+        return infos
+
+
+def _stub_construct_rpc_backend_options_handler(**kwargs):
+    """Ignore all keyword arguments and return a fresh ``mock.Mock`` as the options object."""
+    return mock.Mock()  # RpcBackendOptions.
+
+
+def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options):
+    """Return a ``StubRpcAgent`` for ``world_size``; the other arguments are unused."""
+    return StubRpcAgent(world_size=world_size)
+
+
+def set_value(value):
+    """Complete the module-level ``VALUE_FUTURE`` with ``value``."""
+    VALUE_FUTURE.set_result(value)
+
+
+def wait_for_value_future():
+    """Block on the module-level ``VALUE_FUTURE`` and return its result."""
+    return VALUE_FUTURE.result()
+
+
+def set_and_check_done(value):
+    """Complete ``VALUE_FUTURE`` with ``value``, then block on and return ``DONE_FUTURE``'s result."""
+    VALUE_FUTURE.set_result(value)
+    return DONE_FUTURE.result()
+
+
+# TensorClass is used to test Python user-defined functions over RPC.
+# The classes and functions that follow are used to test Python user-defined
+# classes and methods over RPC.
+TensorClass = namedtuple("TensorClass", ["tensors"])
+
+
+class MyPickleClass:
+    """Object whose pickling and unpickling go through ``_internal_rpc_pickler``."""
+
+    def __init__(self) -> None:
+        # Holds the result computed in __setstate__, or whatever set() stores.
+        self.t = None
+
+    def __getstate__(self):
+        """Serialize a ``PythonUDF`` wrapping ``my_tensor_function`` with two 2x2 ones tensors."""
+        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
+            PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
+        )
+        return (pickled_python_udf, tensors)
+
+    def __setstate__(self, obj):
+        """Deserialize the UDF from ``obj``, call it on its first two args and store the result in ``t``."""
+        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
+        result = python_udf.func(python_udf.args[0], python_udf.args[1])
+        self.t = result
+
+    def set(self, val):
+        """Overwrite ``t`` with ``val``."""
+        self.t = val
+
+
+class SlowPickleClass:
+    """Object that sleeps for ``t`` seconds both when pickled and when unpickled."""
+
+    def __init__(self, t):
+        self.t = t
+
+    def __getstate__(self):
+        """Sleep for ``t`` seconds, then return the one-element state tuple ``(t,)``."""
+        time.sleep(self.t)
+        return (self.t,)
+
+    def __setstate__(self, obj):
+        """Restore ``t`` from ``obj[0]``, then sleep for that long."""
+        self.t = obj[0]
+        time.sleep(self.t)
+
+
+class MyClass:
+    """Simple value holder with instance, class and static methods."""
+
+    def __init__(self, a, delay=False):
+        """Store ``a``; sleep two seconds afterwards when ``delay`` is truthy."""
+        self.a = a
+        # delay initialization to simulate errors if specified
+        if delay:
+            time.sleep(2)
+
+    def my_instance_method(self, b):
+        """Return ``self.a + b``."""
+        return self.a + b
+
+    @classmethod
+    def my_class_method(cls, d, e):
+        """Return ``d + e``; ``cls`` is not used."""
+        return d + e
+
+    @staticmethod
+    def my_static_method(f):
+        """Return the result of ``f > 10``."""
+        return f > 10
+
+    def increment_value(self, increment):
+        """Apply ``self.a += increment``."""
+        self.a += increment
+
+    def get_value(self):
+        """Return the stored value ``a``."""
+        return self.a
+
+    def my_slow_method(self, my_tensor_arg):
+        """Sleep five seconds, then return ``torch.add(self.a, my_tensor_arg)``."""
+        time.sleep(5)
+        return torch.add(self.a, my_tensor_arg)
+
+
+def _call_method_on_rref(method, rref, *args, **kwargs):
+    """Call ``method`` with ``rref.local_value()`` as first argument, followed by ``args``/``kwargs``."""
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def get_rref_list(values):
+    """Return a list with one ``RRef(MyClass(a))`` per element ``a`` of ``values``."""
+    return [RRef(MyClass(a)) for a in values]
+
+
+def add_rref_to_value(rref, value):
+    """Return ``rref.to_here() + value``."""
+    return rref.to_here() + value
+
+
+def run_nested_pickle(pickle_cls_instance, tensor):
+    """Return ``pickle_cls_instance.t + tensor``."""
+    return pickle_cls_instance.t + tensor
+
+
+def build_sparse_tensor(coalesce=False):
+    """Build a fixed 2x3 sparse COO tensor with three stored values.
+
+    Args:
+        coalesce: when truthy, the coalesced form of the tensor is returned.
+    """
+    indices = [[0, 1, 1], [2, 0, 2]]
+    values = [3, 4, 5]
+    sparse = torch.sparse_coo_tensor(indices, values, (2, 3))
+    return sparse.coalesce() if coalesce else sparse
+
+
+def build_complex_tensors():
+    """Return five nested containers that all share a single 3x3 ones tensor.
+
+    The result is ``[a, [a, a], [[a, a], [a, a]], [a, [a, a]], {a: [a, [a, a]]}]``
+    where the inner lists are shared objects, not copies.
+    """
+    base = torch.ones(3, 3)
+    pair = [base, base]
+    nested = [pair, pair]
+    mixed = [base, pair]
+    keyed = {base: mixed}
+    return [base, pair, nested, mixed, keyed]
+
+
+def non_cont_test(t_view, t_cont):
+    """Validate a non-contiguous view against its contiguous copy and return the view.
+
+    Raises:
+        Exception: if ``t_view`` is contiguous, if ``t_cont`` is not
+            contiguous, or if the two tensors are not equal (checked in that
+            order).
+    """
+    problem = None
+    if t_view.is_contiguous():
+        problem = "t_view is contiguous!"
+    elif not t_cont.is_contiguous():
+        problem = "t_cont is not contiguous!"
+    elif not torch.equal(t_view, t_cont):
+        problem = "t_view is not equal to t_cont!"
+
+    if problem is not None:
+        raise Exception(problem)  # noqa: TRY002
+    return t_view
+
+
+def my_function(a, b, c):
+    """Return ``a + b + c``, evaluated left to right."""
+    first_two = a + b
+    return first_two + c
+
+
+def my_tensor_function(a, b):
+    """Return ``a + b``."""
+    return a + b
+
+
+def my_container_sum(a):
+    """Accumulate every element of ``a`` into its first element and return it.
+
+    The accumulation uses ``+=`` on ``a[0]``, so a mutable first element is
+    updated in place.
+    """
+    total = a[0]
+    remaining = a[1:]
+    for item in remaining:
+        total += item
+    return total
+
+
+def my_sleep_func(seconds=1):
+    """Sleep for ``seconds``, then return ``torch.mul`` of two scalar ``1`` tensors."""
+    time.sleep(seconds)
+    lhs = torch.tensor(1)
+    rhs = torch.tensor(1)
+    return torch.mul(lhs, rhs)
+
+
+def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
+    """Accumulate list and dict values into ``list_input[0]`` and return a 4-tuple.
+
+    The accumulator starts as ``list_input[0]`` and every element of
+    ``list_input`` (including the first) and every value of ``dict_input`` is
+    added with ``+=``. The tuple holds the accumulator followed by the first
+    three entries of ``tensor_class_input.tensors``.
+    """
+    acc = list_input[0]
+    for item in list_input:
+        acc += item
+    for item in dict_input.values():
+        acc += item
+    picked = tensor_class_input.tensors
+    return (acc, *(picked[idx] for idx in range(3)))
+
+
+def my_rref_function(rref_a, rref_b):
+    """Return ``rref_a.to_here() + rref_b.to_here()``."""
+    return rref_a.to_here() + rref_b.to_here()
+
+
+def delayed_add(a, b, seconds=0.05):
+    """Sleep for ``seconds`` and then return ``a + b``."""
+    time.sleep(seconds)
+    total = a + b
+    return total
+
+
+def identity(a):
+    """Return ``a`` unchanged."""
+    return a
+
+
+def no_result():
+    """Print a fixed message; nothing is returned."""
+    print("do nothing")
+
+
+def raise_or_inc(value):
+    """Return ``value + 1`` unless ``value`` has exactly two elements.
+
+    Raises:
+        ValueError: when ``value.numel() == 2``.
+    """
+    if value.numel() != 2:
+        return value + 1
+    raise ValueError("Expected error")
+
+
+def nested_rpc(dst):
+    """Synchronously run ``torch.add(torch.ones(2, 2), 1)`` on ``dst`` and return the result."""
+    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+def nested_rpc_sparse(dst):
+    """Synchronously run ``torch.add`` on ``dst`` with two ``build_sparse_tensor()`` operands."""
+    return rpc.rpc_sync(
+        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+
+
+def multi_layer_nested_async_rpc(dst, world_size, ttl):
+    """Fire one async RPC to the next hop while ``ttl`` is positive.
+
+    This method returns immediately without blocking the callee, but will
+    generate additional requests. It returns 0 when a request was sent and
+    None when ``ttl`` is not greater than zero.
+    """
+    if not ttl > 0:
+        return None
+
+    target = worker_name(dst)
+    following = (dst + 1) % world_size
+    rpc.rpc_async(
+        target,
+        multi_layer_nested_async_rpc,
+        args=(following, world_size, ttl - 1),
+    )
+    return 0
+
+
+def nested_rref(dst):
+    """Return a pair of RRefs created on ``dst``: ``ones(2, 2) + 1`` and ``ones(2, 2) + 2``."""
+    return (
+        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
+        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
+    )
+
+
+def nested_rref_sparse(dst):
+    """Return a pair of RRefs created on ``dst``, each adding two ``build_sparse_tensor()`` results."""
+    return (
+        rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())),
+        rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())),
+    )
+
+
+def nested_remote(dst):
+    """Create an RRef for ``torch.add(ones(2, 2), 3)`` on ``dst`` and return its value via ``to_here()``."""
+    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
+    return rref.to_here()
+
+
+def nested_remote_sparse(dst):
+    """Create an RRef adding two ``build_sparse_tensor()`` results on ``dst`` and return its value."""
+    rref = rpc.remote(
+        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+    return rref.to_here()
+
+
+def rref_forward_chain(dst, world_size, rref, ttl):
+    """Forward ``rref`` hop by hop until ``ttl`` is exhausted.
+
+    While ``ttl`` is positive, a remote call to this same function is made on
+    worker ``dst`` and its RRef is returned in a one-element list. Otherwise
+    the value of ``rref`` is fetched with ``to_here()`` and returned.
+    """
+    if not ttl > 0:
+        return rref.to_here()
+
+    target = worker_name(dst)
+    following = (dst + 1) % world_size
+    hop_args = (following, world_size, rref, ttl - 1)
+    return [rpc.remote(target, rref_forward_chain, args=hop_args)]
+
+
+def rpc_return_rref(dst):
+    """Return an RRef for ``torch.add(torch.ones(2, 2), 1)`` created on ``dst``."""
+    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+def light_rpc():
+    """Do no work and return 0."""
+    return 0
+
+
+def heavy_rpc(tensor):
+    """Apply 99 rounds of in-place multiply and divide to ``tensor``, then return 0.
+
+    Round ``k`` (1..99) multiplies by ``k`` and divides by ``k + 1``; both
+    steps use augmented assignment on ``tensor``.
+    """
+    for step in range(1, 100):
+        divisor = step + 1
+        tensor *= step
+        tensor /= divisor
+    return 0
+
+
+def heavy_rpc_sparse(tensor):
+    """Run 99 rounds of multiply and divide starting from ``tensor``, then return 0.
+
+    The multiply uses ``*=`` while the divide rebinds the local name to a new
+    result, so only the first multiply (by 1) is applied to the caller's object.
+    """
+    for step in range(1, 100):
+        divisor = step + 1
+        tensor *= step
+        tensor = tensor / divisor
+    return 0
+
+
+@torch.jit.script
+def heavy_rpc_torchscript(tensor):
+    """Scripted counterpart of ``heavy_rpc``: 99 rounds of ``*= i`` and ``/= i + 1``, then return 0."""
+    for i in range(1, 100):
+        tensor *= i
+        tensor /= i + 1
+    return 0
+
+
+@torch.jit.script
+def my_script_func(tensor):
+    """Return ``torch.add(tensor, tensor)``."""
+    return torch.add(tensor, tensor)
+
+
+# Message used by raise_func() below.
+expected_err = "Expected error"
+
+
+# Note that it needs to inherit from Exception, not BaseException. See comment
+# in rpc/internal.py
+class CustomException(Exception):
+    """Exception carrying an extra ``bool`` attribute next to its message."""
+
+    def __init__(self, bool, msg):
+        """Initialise the base ``Exception`` with ``msg`` and store ``bool`` on the instance."""
+        super().__init__(msg)
+        self.bool = bool
+
+
+def raise_func():
+    """Always raise ``ValueError`` carrying the module-level ``expected_err`` message."""
+    raise ValueError(expected_err)
+
+
+def custom_raise_func():
+    """Always raise ``CustomException(True, "foo")``."""
+    raise CustomException(True, "foo")
+
+
+@torch.jit.script
+def raise_func_script(expected_err: str) -> torch.Tensor:
+    """Always raise ``ValueError`` with the given message; no tensor is ever returned."""
+    raise ValueError(expected_err)
+
+
+# Multi-line message (contains newline characters) used by raise_func_escape().
+expected_err_escape = (
+    "\nFirst line of error \n next line of error \n last line of error"
+)
+
+
+def raise_func_escape():
+    """Always raise ``ValueError`` carrying the module-level ``expected_err_escape`` message."""
+    raise ValueError(expected_err_escape)
+
+
+# Module-level slot written by set_global_rref() and cleared by clear_global_rref().
+global_rref = None
+
+
+def set_global_rref(rref):
+    """Store ``rref`` in the module-level ``global_rref``."""
+    global global_rref
+    global_rref = rref
+
+
+def clear_global_rref():
+    """Reset the module-level ``global_rref`` to ``None``."""
+    global global_rref
+    global_rref = None
+
+
+def check_rref_confirmed(rref):
+    """Return ``rref.confirmed_by_owner()``."""
+    return rref.confirmed_by_owner()
+
+
+def get_rref_debug_info():
+    """Return the result of ``_rref_context_get_debug_info()`` for this process."""
+    debug_info = _rref_context_get_debug_info()
+    return debug_info
+
+
+def add_use_future_cb(to, x, y, z):
+    """Return ``torch.add(x, y) + z``, with the add executed on ``to`` via RPC.
+
+    A ``concurrent.futures.Future`` is completed from a callback attached to
+    the RPC future, and this function blocks on its ``result()``.
+    """
+    result_holder = concurrent.futures.Future()
+    rpc_fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    rpc_fut.then(lambda done: result_holder.set_result(done.wait() + z))
+    return result_holder.result()
+
+
+def get_events_from_profile(profile_rref):
+    """Return ``process_global_function_events`` of the object held locally by ``profile_rref``."""
+    profile = profile_rref.local_value()
+    return profile.process_global_function_events
+
+
+def add_use_future_set_result(to, x, y, z):
+    """Return ``torch.add(x, y) + z``, with the add executed on ``to`` via RPC.
+
+    The sum is published through a ``torch.futures.Future`` that a callback on
+    the RPC future completes; this function waits on it.
+    """
+    combined = torch.futures.Future()
+
+    def _publish(done):
+        return combined.set_result(done.wait() + z)
+
+    rpc.rpc_async(to, torch.add, args=(x, y)).then(_publish)
+    return combined.wait()
+
+
+def add_use_future_nested_cb(to, x, y, z):
+    """Return ``(x + y) + z`` computed by two chained RPC adds on ``to``.
+
+    The callback of the first RPC issues the second RPC, whose own callback
+    completes the ``torch.futures.Future`` this function waits on.
+    """
+    final = torch.futures.Future()
+
+    def _forward_result(second):
+        return final.set_result(second.wait())
+
+    def _chain_second_add(first):
+        second = rpc.rpc_async(to, torch.add, args=(first.wait(), z))
+        second.then(_forward_result)
+
+    first = rpc.rpc_async(to, torch.add, args=(x, y))
+    first.then(_chain_second_add)
+    return final.wait()
+
+
+def fail_on_fut(fut):
+    """Accept a future argument and do nothing with it."""
+    pass
+
+
+@rpc.functions.async_execution
+def async_raise_func():
+    """Async-execution function that raises ``RuntimeError`` instead of returning a future."""
+    error = RuntimeError("Expected error")
+    raise error
+
+
+@rpc.functions.async_execution
+def async_wrong_type():
+    """Async-execution function that returns a plain 2x2 zero tensor rather than a future."""
+    not_a_future = torch.zeros(2, 2)
+    return not_a_future
+
+
+@rpc.functions.async_execution
+def async_add(to, x, y):
+    """Return the future of an ``rpc_async`` call computing ``torch.add(x, y)`` on ``to``."""
+    add_fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    return add_fut
+
+
+def slow_add(x, y, device="cpu"):
+    """Sleep one second, add ``x`` and ``y`` on ``device``, and return the sum on CPU."""
+    time.sleep(1)
+    lhs = x.to(device)
+    rhs = y.to(device)
+    total = torch.add(lhs, rhs)
+    return total.cpu()
+
+
+@rpc.functions.async_execution
+def slow_async_add(to, x, y, device="cpu"):
+    """Return the future of an ``rpc_async`` call running ``slow_add(x, y, device)`` on ``to``."""
+    slow_fut = rpc.rpc_async(to, slow_add, args=(x, y, device))
+    return slow_fut
+
+
+@rpc.functions.async_execution
+def async_add_with_future_ctor(to, x, y, z):
+    """Return a hand-built future that resolves to ``torch.add(x, y) + z``.
+
+    The add runs on ``to`` via ``rpc_async``; its callback completes the
+    explicitly constructed ``torch.futures.Future`` that is returned.
+    """
+    outcome = torch.futures.Future()
+
+    def _complete(add_fut):
+        return outcome.set_result(add_fut.wait() + z)
+
+    rpc.rpc_async(to, torch.add, args=(x, y)).then(_complete)
+    return outcome
+
+
+@rpc.functions.async_execution
+def async_add_chained(to, x, y, z):
+    """Return a chained future resolving to ``torch.add(x, y) + z``, the add running on ``to``."""
+
+    def _plus_z(add_fut):
+        return add_fut.wait() + z
+
+    add_fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    return add_fut.then(_plus_z)
+
+
+@rpc.functions.async_execution
+def async_add_chained_multi(to, x, num, step):
+    """Return a future of ``x + 0`` (computed on ``to``) with ``num`` chained ``+ step`` callbacks."""
+
+    def _add_step(prev):
+        return prev.wait() + step
+
+    chained = rpc.rpc_async(to, torch.add, args=(x, 0))
+    for _ in range(num):
+        chained = chained.then(_add_step)
+    return chained
+
+
+@rpc.functions.async_execution
+def async_add_nested(to, x, y, z):
+    """Call ``async_add(to, x, y)`` on ``to`` and chain ``+ z`` onto its future."""
+
+    def _plus_z(inner_fut):
+        return inner_fut.wait() + z
+
+    inner_fut = rpc.rpc_async(to, async_add, args=(to, x, y))
+    return inner_fut.then(_plus_z)
+
+
+@rpc.functions.async_execution
+def async_add_multi_fanout(to, x, num, step):
+    """Fan out ``num`` RPC adds to ``to`` and return a future of their sum.
+
+    The first request computes ``x + step`` and every other request computes
+    ``0 + step``. Each completion is accumulated under a lock; the returned
+    future is completed once all ``num`` results have been added.
+    """
+    futs = [
+        rpc.rpc_async(to, torch.add, args=(x if idx == 0 else 0, step))
+        for idx in range(num)
+    ]
+
+    # TODO: use torch.futures.collect_all
+    guard = Lock()
+    progress = {"cnt": 0, "ret": torch.zeros_like(x)}
+    ret_future = torch.futures.Future()
+
+    def inc_and_set(fut):
+        with guard:
+            progress["cnt"] += 1
+            progress["ret"] += fut.wait()
+            if len(futs) <= progress["cnt"]:
+                ret_future.set_result(progress["ret"])
+
+    for pending in futs:
+        pending.then(inc_and_set)
+
+    return ret_future
+
+
+@rpc.functions.async_execution
+def async_cuda_sleep_and_set_to_one(t):
+    """Fill ``t`` with ones on a side CUDA stream and return it in a future.
+
+    A new stream on ``t``'s device is made to wait on that device's current
+    stream. Inside the new stream the function calls ``torch.cuda._sleep``,
+    fills ``t`` with 1 in place, and returns a ``Future(devices=[device])``
+    whose result has already been set to ``t``.
+    """
+    device = t.device
+    original_stream = torch.cuda.current_stream(device)
+    new_stream = torch.cuda.Stream(device)
+    # Order the new stream after work already queued on the current stream.
+    new_stream.wait_stream(original_stream)
+    with torch.cuda.stream(new_stream):
+        # NOTE(review): 1000 * get_cycles_per_ms() is presumably about one
+        # second of GPU spin — confirm against get_cycles_per_ms.
+        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+        t.fill_(1)
+        fut = Future(devices=[device])
+        fut.set_result(t)
+        return fut
+
+
+@rpc.functions.async_execution
+def async_cuda_nested_add(to, x, y, z):
+    """Return a future of ``torch.add(x, y) + z``, with a CUDA sleep in the callback.
+
+    The add runs on ``to`` via ``rpc_async``; the chained callback calls
+    ``torch.cuda._sleep`` before adding ``z`` to the RPC result.
+    """
+
+    def cb(fut):
+        # Delay the callback on the GPU before reading the RPC result.
+        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+        return fut.value() + z
+
+    return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb)
+
+
+# A custom Python class that contains a tensor, needed to see if we correctly
+# use the Python pickler to extract tensors from non-IValue-convertible types.
+class TensorWrapper:
+    """Wrap a tensor together with a lock, a CUDA event and a started thread."""
+
+    __slots__ = ("tensor", "lock", "event", "thread")
+
+    def __init__(self, t):
+        """Store ``t`` and create the lock, CUDA timing event and thread fields."""
+        self.tensor = t
+        # Add one non-picklable field, to ensure it's ignored/skipped.
+        self.lock = Lock()
+        self.event = torch.cuda.Event(enable_timing=True)
+        # The thread has no target; it is started immediately.
+        self.thread = threading.Thread()
+        self.thread.start()
+
+    def increase(self, v):
+        """Add ``v`` to the wrapped tensor in place while holding the lock."""
+        with self.lock:
+            self.tensor += v
+
+    def sum(self):
+        """Record the CUDA event and return ``tensor.sum()`` while holding the lock."""
+        with self.lock:
+            self.event.record()
+            return self.tensor.sum()
+
+
+class AsyncExecutionClass:
+    """Async-execution adds exposed as a static, a class and an instance method.
+
+    Each method issues ``torch.add(x, y)`` on ``to`` via ``rpc_async`` and
+    returns a future that resolves to that result plus ``z``.
+    """
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def static_async_add(to, x, y, z):
+        """Static-method variant; returns the chained ``then`` future."""
+        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: fut.wait() + z
+        )
+
+    @classmethod
+    @rpc.functions.async_execution
+    def class_async_add(cls, to, x, y, z):
+        """Class-method variant; returns an explicitly built future set by the callback."""
+        ret_fut = torch.futures.Future()
+        rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: ret_fut.set_result(fut.wait() + z)
+        )
+        return ret_fut
+
+    @rpc.functions.async_execution
+    def bound_async_add(self, to, x, y, z):
+        """Instance-method variant; returns the chained ``then`` future."""
+        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: fut.wait() + z
+        )
+
+
+def return_future():
+    """Return a freshly constructed, not-yet-completed ``torch.futures.Future``."""
+    pending = torch.futures.Future()
+    return pending
+
+
+class FooBackendOptions(rpc.RpcBackendOptions):
+    """``RpcBackendOptions`` subclass that only records an ``init_method``."""
+
+    def __init__(self, init_method):
+        """Initialise the base options and store ``init_method`` on the instance."""
+        # Must call the __init__ of the superclass (and do so directly,
+        # without using super()) because... pybind.
+        rpc.RpcBackendOptions.__init__(self)
+        self.init_method = init_method
+
+
+# load_tests from common_utils is used to automatically filter tests for
+# sharding on sandcastle. The self-assignment below only marks the name as
+# used so that flake8 does not warn about an unused import.
+load_tests = load_tests  # noqa: PLW0127
+
+
+class MyEmbeddingBagModel(torch.nn.Module):
+    """Module consisting of a single ``EmbeddingBag(10, 10)`` layer."""
+
+    def __init__(self, sparse):
+        super().__init__()
+        embedding = torch.nn.EmbeddingBag(10, 10, sparse=sparse)
+        self.eb = embedding
+
+    def forward(self, x):
+        """Return the embedding-bag output for the index tensor ``x``."""
+        pooled = self.eb(x)
+        return pooled
+
+
+class MyParameterServer:
+    """Toy parameter server that averages tensors reported by ``trainers`` workers."""
+
+    def __init__(self, trainers):
+        self.trainers = trainers
+        self.lock = Lock()
+        self.futures = []
+        self.iteration = 0
+        self.updates = 0
+        self.total = None
+        self.gradient = None
+
+    @staticmethod
+    def get_gradient(rref):
+        """Return the ``gradient`` attribute of the server held locally by ``rref``."""
+        server = rref.local_value()
+        return server.gradient
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def average(rref, riteration, tensor):
+        """Accumulate ``tensor`` and return a future for the averaged result.
+
+        A newer ``riteration`` resets the update counter and the list of
+        pending futures. Once ``trainers`` updates have arrived, every pending
+        future is completed with ``total / trainers``.
+        """
+        server = rref.local_value()
+        reply = torch.futures.Future()
+        with server.lock:
+            if riteration > server.iteration:
+                server.iteration = riteration
+                server.updates = 0
+                server.futures.clear()
+            server.futures.append(reply)
+            if server.total is not None:
+                server.total += tensor
+            else:
+                server.total = tensor
+            server.updates += 1
+            if server.trainers == server.updates:
+                server.gradient = server.total / float(server.trainers)
+                for waiter in server.futures:
+                    averaged = server.total / float(server.trainers)
+                    waiter.set_result(averaged)
+        return reply
+
+
+class MyConvNetForMNIST(nn.Module):
+    """Small conv net placed on ``device`` whose forward pass adds a CUDA delay."""
+
+    def __init__(self, device):
+        """Build the Conv/ReLU/MaxPool/Linear stack and move it to ``device``."""
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(1, 16, 3, 1),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, 3, 1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Flatten(1),
+            nn.Linear(4608, 128),
+            nn.ReLU(),
+            nn.Linear(128, 10),
+        ).to(device)
+        self.device = device
+
+    def forward(self, x, is_rref=False):
+        """Run the network on ``x``, first fetching it with ``to_here()`` when ``is_rref`` is true."""
+        x = x.to_here() if is_rref else x
+        with torch.cuda.stream(torch.cuda.current_stream(self.device)):
+            # intentionally adding delay to current CUDA stream
+            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+            return self.net(x)
+
+    def __getstate__(self):
+        # return an empty dict to avoid inspecting the model contents on the
+        # owner
+        return {}
+
+
+class RpcTestCommon:
+    def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None):
+        """Invoke ``fn`` on ``to`` using the requested RPC flavour and return its value.
+
+        Args:
+            to: destination worker, forwarded to the RPC API unchanged.
+            fn: callable to execute remotely.
+            mode: ``RPCExecMode.SYNC`` (``rpc_sync``), ``RPCExecMode.ASYNC``
+                (``rpc_async`` followed by ``wait()``) or ``RPCExecMode.REMOTE``
+                (``remote`` followed by ``to_here()``).
+            args: positional arguments for ``fn``.
+            kwargs: keyword arguments for ``fn``.
+
+        Raises:
+            ValueError: if ``mode`` is not one of the three supported modes.
+                Previously this case silently returned ``None``, which hid
+                mistakes in the calling test.
+        """
+        if mode == RPCExecMode.SYNC:
+            return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs)
+        if mode == RPCExecMode.ASYNC:
+            return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait()
+        if mode == RPCExecMode.REMOTE:
+            return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here()
+        raise ValueError(f"Unsupported RPC execution mode: {mode!r}")
+
+    def _self_py_udf_remote(self, worker_info, x, y, z):
+        """Run ``my_function(x, y, z)`` via ``rpc.remote`` on ``worker_info`` and check it equals ``x + y + z``."""
+        remote_result = rpc.remote(worker_info, my_function, args=(x, y, z))
+        fetched = remote_result.to_here()
+        self.assertEqual(fetched, x + y + z)
+
+    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
+        """Create an RRef on this worker and pass it as an argument to async and sync RPCs on ``dst``."""
+        me = rpc.get_worker_info()
+        local_rref = rpc.remote(me, my_function, args=(x, y, z))
+        # Issue the async call first, then the blocking one.
+        pending = rpc.rpc_async(dst, add_rref_to_value, args=(local_rref, x))
+        sync_result = rpc.rpc_sync(dst, add_rref_to_value, args=(local_rref, x + y))
+        self.assertEqual(sync_result, x + y + z + x + y)
+        self.assertEqual(pending.wait(), x + y + z + x)
+
+    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
+        """Create an RRef on this worker and pass it as an argument to ``rpc.remote`` on ``dst``."""
+        me = rpc.get_worker_info()
+        local_rref = rpc.remote(me, my_function, args=(x, y, z))
+        result_rref = rpc.remote(dst, add_rref_to_value, args=(local_rref, x))
+        fetched = result_rref.to_here()
+        self.assertEqual(fetched, x + y + z + x)
+
+    def _world_size_one(self, a, b):
+        """On rank 0 only, run a one-worker RPC group and call ``my_tensor_function`` on itself.
+
+        The function is invoked through ``rpc_sync``, ``rpc_async`` and
+        ``remote`` in that order; each result is compared with ``a * 2``.
+        """
+        if self.rank != 0:
+            return
+
+        rpc.init_rpc(
+            name="me",
+            backend=self.rpc_backend,
+            rank=0,
+            world_size=1,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        invokers = (
+            lambda x, y: rpc.rpc_sync("me", my_tensor_function, args=(x, y)),
+            lambda x, y: rpc.rpc_async("me", my_tensor_function, args=(x, y)).wait(),
+            lambda x, y: rpc.remote("me", my_tensor_function, args=(x, y)).to_here(),
+        )
+        for invoke in invokers:
+            # The expectation is computed before each call.
+            expect = a * 2
+            result = invoke(a, b)
+            self.assertEqual(expect, result)
+
+        rpc.shutdown()
+
+    def _multi_rpc(self, sparse):
+        """Send 20 synchronous ``torch.add`` RPCs to the next rank and check each result is ``x * 2``."""
+        target = (self.rank + 1) % self.world_size
+        for step in range(20):
+            scale = step + self.rank + 1
+            if not sparse:
+                x = torch.ones(2, 2)
+                y = torch.ones(2, 2)
+            else:
+                x = build_sparse_tensor() * scale
+                y = build_sparse_tensor() * scale
+            ret = rpc.rpc_sync(worker_name(target), torch.add, args=(x, y))
+            self.assertEqual(ret, x * 2)
+
+    def _run_uneven_workload(self, f, x, num_repeat=30):
+        """From rank 0, send ``num_repeat`` calls of ``f(x)`` to worker1 and then to worker2.
+
+        worker0 drives and waits for worker1 and worker2 throughout the test;
+        every call is expected to return 0. Other ranks do nothing here.
+        """
+        if self.rank != 0:
+            return
+
+        self.assertTrue(self.world_size >= 3)
+
+        # Phase 1: Only worker1 has workload.
+        phase_one = [
+            rpc.rpc_async("worker1", f, args=(x,)) for _ in range(num_repeat)
+        ]
+        for done in torch.futures.collect_all(phase_one).wait():
+            self.assertEqual(done.wait(), 0)
+
+        # Phase 2: Only worker2 has workload.
+        # If join is not correctly implemented,
+        # worker2 should be closed by now.
+        phase_two = [
+            rpc.rpc_async("worker2", f, args=(x,)) for _ in range(num_repeat)
+        ]
+        for val in torch.futures.wait_all(phase_two):
+            self.assertEqual(val, 0)
+
+    def _wait_all_workers(self, f, x):
+        """Initialise RPC, run the uneven workload, then call ``_wait_all_workers`` once before shutdown."""
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        own_name = f"worker{self.rank:d}"
+        rpc.init_rpc(
+            name=own_name,
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        self._run_uneven_workload(f, x)
+
+        # worker0 calls this at the end, after waiting for its RPC responses.
+        # worker1/2 call this immediately and still have work to do afterwards.
+        # worker3 calls this immediately and has no more work.
+        rpc.api._wait_all_workers()
+
+        # Wait before proceeding to shutdown to ensure worker0 RPCs make
+        # it through to other workers.
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+
+    def _wait_all_workers_twice(self, f, x):
+        """Initialise RPC, run the uneven workload, then call ``_wait_all_workers`` twice before shutdown."""
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        own_name = f"worker{self.rank:d}"
+        rpc.init_rpc(
+            name=own_name,
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        self._run_uneven_workload(f, x)
+
+        # worker0 calls this at the end, after waiting for its RPC responses.
+        # worker1/2 call this immediately and still have work to do afterwards.
+        # worker3 calls this immediately and has no more work.
+        for _ in range(2):
+            rpc.api._wait_all_workers()
+
+        # Wait before proceeding to shutdown to ensure worker0 RPCs make
+        # it through to other workers.
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+
+    def _nested_rpc(self, f, expected):
+        """Call ``f(worker_name(self.rank))`` synchronously on the next rank and compare with ``expected``."""
+        next_rank = (self.rank + 1) % self.world_size
+        caller_name = (worker_name(self.rank),)
+        ret = rpc.rpc_sync(worker_name(next_rank), f, args=caller_name)
+        self.assertEqual(ret, expected)
+
+    def _stress_test_rpc(self, f, repeat=1000, args=()):
+        """Fire ``repeat`` async calls of ``f(*args)`` at the next rank, expect 0 from each, and print the elapsed time."""
+        dst_rank = (self.rank + 1) % self.world_size
+        tik = time.time()
+        futs = [
+            rpc.rpc_async(worker_name(dst_rank), f, args=args) for _ in range(repeat)
+        ]
+
+        for val in torch.futures.wait_all(futs):
+            self.assertEqual(val, 0)
+        tok = time.time()
+        print(
+            f"Rank {self.rank} finished testing {repeat} times in {tok - tik} seconds."
+        )
+
+    def _builtin_remote_ret(self, x, y, expected):
+        """Run ``torch.add(x, y)`` via ``rpc.remote`` on the next rank and compare the fetched value with ``expected``."""
+        next_rank = (self.rank + 1) % self.world_size
+        sum_rref = rpc.remote(worker_name(next_rank), torch.add, args=(x, y))
+        self.assertEqual(sum_rref.to_here(), expected)
+
+    def _builtin_remote_self(self, x, y, expected):
+        """Run ``torch.add(x, y)`` via ``rpc.remote`` on this worker and compare ``local_value()`` with ``expected``."""
+        own_name = worker_name(self.rank)
+        sum_rref = rpc.remote(own_name, torch.add, args=(x, y))
+        self.assertEqual(sum_rref.local_value(), expected)
+
+    def _test_multi_remote_call(
+        self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}
+    ):
+        """Issue ten ``rpc.remote`` calls of ``fn`` and compare each with a local call.
+
+        ``args_fn`` and ``kwargs_fn`` build the arguments from a running counter
+        and ``sparse``; the counter grows by the loop index on every iteration.
+        """
+        num_calls = 10
+        counter = self.rank + 1
+        dst_rank = counter % self.world_size
+        rrefs = []
+        expected = []
+        for offset in range(num_calls):
+            counter += offset
+            remote_result = rpc.remote(
+                worker_name(dst_rank),
+                fn,
+                args=args_fn(counter, sparse),
+                kwargs=kwargs_fn(counter, sparse),
+            )
+            rrefs.append(remote_result)
+            local_result = fn(*args_fn(counter, sparse), **kwargs_fn(counter, sparse))
+            expected.append(local_result)
+
+        for remote_result, local_result in zip(rrefs, expected):
+            self.assertEqual(remote_result.to_here(), local_result)
+
+    def _py_rref_args(self, a, b, x, y, expected):
+        """Create two ``torch.add`` RRefs on the next rank and feed them to ``my_rref_function`` there."""
+        dst = (self.rank + 1) % self.world_size
+        first = rpc.remote(worker_name(dst), torch.add, args=(a, b))
+        second = rpc.remote(worker_name(dst), torch.add, args=(x, y))
+        combined = rpc.remote(worker_name(dst), my_rref_function, args=(first, second))
+        self.assertEqual(combined.to_here(), expected)
+
+    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
+        """Create two RRefs on one rank and pass them to ``my_rref_function`` on the rank after it."""
+        base = self.rank + 1
+        owner = base % self.world_size
+        user = (base + 1) % self.world_size
+        first = rpc.remote(worker_name(owner), my_function, args=(a, b, c))
+        second = rpc.remote(worker_name(owner), my_function, args=(x, y, z))
+        combined = rpc.remote(worker_name(user), my_rref_function, args=(first, second))
+        self.assertEqual(combined.to_here(), expected)
+
+    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
+        """Create two RRefs on the next rank and pass them to ``my_rref_function`` there via ``rpc_sync``."""
+        dst = (self.rank + 1) % self.world_size
+        first = rpc.remote(worker_name(dst), my_function, args=(a, b, c))
+        second = rpc.remote(worker_name(dst), my_function, args=(x, y, z))
+
+        result = rpc.rpc_sync(worker_name(dst), my_rref_function, args=(first, second))
+        self.assertEqual(result, expected)
+
+    def _nested_remote(self, f, expected):
+        """Run ``f`` on the next rank, passing it the name of the rank after that, and check the fetched value."""
+        base = self.rank + 1
+        first_hop = worker_name(base % self.world_size)
+        second_hop = worker_name((base + 1) % self.world_size)
+
+        rref = rpc.remote(first_hop, f, args=(second_hop,))
+        self.assertEqual(rref.to_here(), expected)
+
+    def _nested_rref(self, f, expected1, expected2):
+        """Fetch a pair of RRefs produced by ``f`` on the next rank and check both values."""
+        base = self.rank + 1
+        first_hop = worker_name(base % self.world_size)
+        second_hop = worker_name((base + 1) % self.world_size)
+        rref_of_rrefs = rpc.remote(first_hop, f, args=(second_hop,))
+
+        # Say C has 2 OwnerRRefs.
+        # B has 2 UserRRefs to those 2 OwnerRRefs, respectively.
+        # This call is effectively A asking B to share its 2 UserRRefs.
+        rrefs = rref_of_rrefs.to_here()
+
+        self.assertEqual(len(rrefs), 2)
+        for position, wanted in enumerate((expected1, expected2)):
+            self.assertEqual(rrefs[position].to_here(), wanted)
+
+    def _nested_rref_stress(self, f, expected1, expected2):
+        """Create 20 RRef-of-RRefs via ``f`` on the next rank, then fetch and check each pair."""
+        base = self.rank + 1
+        dst_rank1 = base % self.world_size
+        dst_rank2 = (base + 1) % self.world_size
+        all_rrefs = []
+        for _ in range(20):
+            all_rrefs.append(
+                rpc.remote(worker_name(dst_rank1), f, args=(worker_name(dst_rank2),))
+            )
+
+        for rref_of_rrefs in all_rrefs:
+            rrefs = rref_of_rrefs.to_here()
+            self.assertEqual(len(rrefs), 2)
+            self.assertEqual(rrefs[0].to_here(), expected1)
+            self.assertEqual(rrefs[1].to_here(), expected2)
+
+    def _trainer_func(self, rref, sparse):
+        """Run ten training steps, averaging the gradient through the server behind ``rref``.
+
+        After each step the averaged gradient returned by ``average`` must
+        equal the one reported by ``get_gradient``; sparse tensors are
+        converted to dense doubles before the comparison.
+        """
+
+        def _as_dense_double(t):
+            return t.to_dense().double() if t.is_sparse else t
+
+        model = MyEmbeddingBagModel(sparse=sparse)
+        criterion = nn.MSELoss()
+        for step in range(10):
+            prediction = model(torch.rand(10, 10).long())
+            criterion(prediction, torch.rand(10, 10)).backward()
+            local_grad = next(iter(model.parameters())).grad
+            pending = rref.rpc_async().average(rref, step, local_grad)
+            averaged = _as_dense_double(pending.wait())
+            served = _as_dense_double(rref.rpc_sync().get_gradient(rref))
+            self.assertTrue(torch.equal(averaged, served))
+
+    def _my_parameter_server(self, sparse):
+        """Host a ``MyParameterServer`` locally and run ``_trainer_func`` on every other rank."""
+        ps_rref = RRef(MyParameterServer(self.world_size - 1))
+        futures = []
+        for index in range(1, self.world_size):
+            trainer = worker_name((self.rank + index) % self.world_size)
+            futures.append(
+                rpc.rpc_async(trainer, self._trainer_func, args=(ps_rref, sparse))
+            )
+        torch.futures.wait_all(futures)
+
+    def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor):
+        """Check that a CUDA ``Future`` hands a tensor safely from one stream to another.
+
+        A tensor is produced on one stream, passed through ``wrapper`` into
+        ``future.set_result``, then read back on a second stream through
+        ``unwrapper(future.wait())`` and compared with the expected value.
+        ``sparse_tensor`` selects sparse or dense operands.
+        """
+        # We check proper CUDA stream synchronization by adding to the tensor
+        # in one stream to get the expected value, and reading it from another stream.
+        future = Future(devices=["cuda:0"])
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                if sparse_tensor:
+                    tensor = build_sparse_tensor().to("cuda:0")
+                    add_tensor = build_sparse_tensor().to("cuda:0")
+                    expected_tensor = (tensor + add_tensor).coalesce()
+                else:
+                    tensor = torch.zeros((100,), device="cuda:0")
+                    add_tensor = torch.ones((100,), device="cuda:0")
+                    expected_tensor = tensor + add_tensor
+                # Delay the stream before the in-place add that yields the
+                # value stored in the future.
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor += add_tensor
+                if sparse_tensor:
+                    tensor = tensor.coalesce()
+                future.set_result(wrapper(tensor))
+            with torch.cuda.stream(another_stream):
+                tensor = unwrapper(future.wait())
+                if sparse_tensor:
+                    # Sparse tensors are compared by indices, values and size.
+                    self.assertTrue(
+                        torch.eq(tensor.indices(), expected_tensor.indices())
+                        .all()
+                        .item()
+                    )
+                    self.assertTrue(
+                        torch.eq(tensor.values(), expected_tensor.values()).all().item()
+                    )
+                    self.assertEqual(tensor.size(), expected_tensor.size())
+                else:
+                    self.assertTrue(torch.eq(tensor, expected_tensor).all().item())
+
+
+class RpcTest(RpcAgentTestFixture, RpcTestCommon):
+    @dist_init
+    def test_worker_id(self):
+        """Worker infos for this rank and the next carry the expected names; an unknown name raises."""
+        peer_rank = (self.rank + 1) % self.world_size
+        own_info = rpc.get_worker_info()
+        peer_info = rpc.get_worker_info(worker_name(peer_rank))
+
+        self.assertEqual(own_info.name, worker_name(self.rank))
+        self.assertEqual(peer_info.name, worker_name(peer_rank))
+
+        with self.assertRaisesRegex(RuntimeError, "could not find destination"):
+            rpc.get_worker_info("WorkerUnknown")
+
+    @dist_init
+    def test_get_worker_infos(self):
+        """The current agent reports one worker info per rank, with matching names and ids."""
+        infos = rpc.api._get_current_rpc_agent().get_worker_infos()
+        all_ranks = range(self.world_size)
+
+        seen_names = {info.name for info in infos}
+        self.assertEqual(seen_names, {worker_name(rank) for rank in all_ranks})
+
+        seen_ids = {info.id for info in infos}
+        self.assertEqual(seen_ids, set(all_ranks))
+
+    @dist_init
+    def test_self_add(self):
+        """``torch.add`` sent to this worker itself works through both ``rpc_async`` and ``rpc_sync``."""
+        me = rpc.get_worker_info()
+        pending = rpc.rpc_async(me, torch.add, args=(torch.ones(2, 2), 1))
+        sync_result = rpc.rpc_sync(me, torch.add, args=(torch.ones(2, 2), 1))
+        self.assertEqual(pending.wait(), torch.ones(2, 2) + 1)
+        self.assertEqual(sync_result, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_send_to_rank(self):
+        """Integer ranks are accepted as destinations; out-of-range and fractional ranks raise."""
+        dst_rank = (self.rank + 1) % self.world_size
+        exec_modes = [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]
+
+        # Test dense tensor
+        for exec_mode in exec_modes:
+            ret = self._run_func_in_mode(
+                dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+            )
+            self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test invalid ranks: out-of-range integers raise RuntimeError,
+        # fractional values raise ValueError.
+        invalid_cases = (
+            (self.world_size + 1, RuntimeError),
+            (-1, RuntimeError),
+            (dst_rank + 0.5, ValueError),
+            (dst_rank - 0.5, ValueError),
+        )
+        for bad_rank, expected_exc in invalid_cases:
+            for exec_mode in exec_modes:
+                with self.assertRaises(expected_exc):
+                    self._run_func_in_mode(
+                        bad_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                    )
+
+    @dist_init
+    def test_self_py_udf_remote(self):
+        """Run the Python-UDF remote check against this worker itself."""
+        me = rpc.get_worker_info()
+        self._self_py_udf_remote(me, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg(self):
+        """Pass a locally owned RRef as an RPC argument to the next rank."""
+        next_rank = (self.rank + 1) % self.world_size
+        self._self_remote_rref_as_rpc_arg(
+            worker_name(next_rank), torch.ones(2, 2), 1, 3
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg(self):
+        """Pass a locally owned RRef as an RPC argument to this worker itself."""
+        me = rpc.get_worker_info()
+        self._self_remote_rref_as_rpc_arg(me, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg(self):
+        """Pass a locally owned RRef as a ``remote`` argument to the next rank."""
+        next_rank = (self.rank + 1) % self.world_size
+        self._self_remote_rref_as_remote_arg(
+            worker_name(next_rank), torch.ones(2, 2), 1, 3
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg(self):
+        """Pass a locally owned RRef as a ``remote`` argument to this worker itself."""
+        me = rpc.get_worker_info()
+        self._self_remote_rref_as_remote_arg(me, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_rref_proxy_non_exist(self):
+        """Calling a missing attribute through any RRef proxy raises ``AttributeError``."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+        msg = "has no attribute 'non_exist'"
+
+        attempts = (
+            lambda: rref.rpc_sync().non_exist(),
+            lambda: rref.rpc_async().non_exist().wait(),
+            lambda: rref.remote().non_exist(),
+        )
+        for attempt in attempts:
+            with self.assertRaisesRegex(AttributeError, msg):
+                attempt()
+
+    def _test_rref_proxy_tensor(self, dst):
+        """Call tensor methods on an RRef owned by ``dst`` through the sync, async and remote proxies."""
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+        expected = torch.ones(2, 2) + 1 + 3
+
+        sync_size = rref.rpc_sync().size()
+        self.assertEqual(expected.size(), sync_size)
+
+        async_sum = rref.rpc_async().add(1).wait()
+        self.assertEqual(expected + 1, async_sum)
+
+        remote_view = rref.remote().view(1, 4).to_here()
+        self.assertEqual(expected.view(1, 4), remote_view)
+
+    @dist_init
+    def test_rref_proxy_tensor(self):
+        """Run the tensor proxy checks against the next rank."""
+        next_rank = (self.rank + 1) % self.world_size
+        self._test_rref_proxy_tensor(worker_name(next_rank))
+
+    @dist_init
+    def test_rref_proxy_tensor_self(self):
+        """Run the tensor proxy checks against this worker itself."""
+        me = rpc.get_worker_info()
+        self._test_rref_proxy_tensor(me)
+
+    @dist_init
+    def test_rref_proxy_reuse(self):
+        """Each kind of RRef proxy can be created once and reused for several calls."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+        expected = torch.ones(2, 2) + 1 + 3
+
+        sync_proxy = rref.rpc_sync()
+        async_proxy = rref.rpc_async()
+        remote_proxy = rref.remote()
+
+        # Blocking proxy: results come back directly.
+        self.assertEqual(expected.size(), sync_proxy.size())
+        self.assertEqual(expected + 1, sync_proxy.add(1))
+        self.assertEqual(expected.view(1, 4), sync_proxy.view(1, 4))
+
+        # Async proxy: results come back as futures.
+        self.assertEqual(expected.size(), async_proxy.size().wait())
+        self.assertEqual(expected + 3, async_proxy.add(3).wait())
+        self.assertEqual(expected.view(4, 1), async_proxy.view(4, 1).wait())
+
+        # Remote proxy: results come back as RRefs.
+        self.assertEqual(expected.size(), remote_proxy.size().to_here())
+        self.assertEqual(expected + 5, remote_proxy.add(5).to_here())
+        self.assertEqual(expected.view(-1), remote_proxy.view(-1).to_here())
+
+    def _test_rref_proxy_class(self, dst):
+        """Exercise a remote ``MyClass(7)`` on ``dst`` through all three RRef proxies.
+
+        A local ``MyClass(7)`` serves as the reference; every instance, static
+        and class method call is made through ``rpc_sync()``, ``rpc_async()``
+        and ``remote()`` and compared with the local result.
+        """
+        rref = rpc.remote(dst, MyClass, args=(7,))
+        expected = MyClass(7)
+        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
+        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
+        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())
+
+        # One local increment of 3 mirrors the three remote increments of 1
+        # below (presumably increment_value adds its argument — see MyClass).
+        expected.increment_value(3)
+        self.assertEqual(None, rref.rpc_sync().increment_value(1))
+        self.assertEqual(None, rref.rpc_async().increment_value(1).wait())
+        self.assertEqual(None, rref.remote().increment_value(1).to_here())
+
+        # Local and remote values must agree again after the increments.
+        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
+        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
+        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())
+
+        # Instance method.
+        self.assertEqual(
+            expected.my_instance_method(2), rref.rpc_sync().my_instance_method(2)
+        )
+        self.assertEqual(
+            expected.my_instance_method(3),
+            rref.rpc_async().my_instance_method(3).wait(),
+        )
+        self.assertEqual(
+            expected.my_instance_method(4),
+            rref.remote().my_instance_method(4).to_here(),
+        )
+
+        # Static method.
+        self.assertEqual(
+            expected.my_static_method(9), rref.rpc_sync().my_static_method(9)
+        )
+        self.assertEqual(
+            expected.my_static_method(10), rref.rpc_async().my_static_method(10).wait()
+        )
+        self.assertEqual(
+            expected.my_static_method(11), rref.remote().my_static_method(11).to_here()
+        )
+
+        # Class method.
+        self.assertEqual(
+            expected.my_class_method(2, torch.zeros(2, 2)),
+            rref.rpc_sync().my_class_method(2, torch.zeros(2, 2)),
+        )
+        self.assertEqual(
+            expected.my_class_method(2, torch.ones(3, 3)),
+            rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait(),
+        )
+        self.assertEqual(
+            expected.my_class_method(2, torch.ones(4, 4)),
+            rref.remote().my_class_method(2, torch.ones(4, 4)).to_here(),
+        )
+
+    @dist_init
+    def test_rref_proxy_class(self):
+        """Run the class proxy checks against the next rank."""
+        next_rank = (self.rank + 1) % self.world_size
+        self._test_rref_proxy_class(worker_name(next_rank))
+
+    @dist_init
+    def test_rref_proxy_class_self(self):
+        """Run the class proxy checks against this worker itself."""
+        me = rpc.get_worker_info()
+        self._test_rref_proxy_class(me)
+
+    @mock.patch.object(torch.distributed.autograd, "_init")
+    @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
+    @dist_init(setup_rpc=False)
+    def test_register_rpc_backend_and_set_and_start_rpc_backend(
+        self, mock_rpc_agent, mock_dist_autograd_init
+    ):
+        """Register a stub backend, reject a duplicate registration, then init RPC with it.
+
+        ``torch.distributed.rpc.api._set_and_start_rpc_agent`` and
+        ``torch.distributed.autograd._init`` are replaced by mocks while the
+        test runs.
+        """
+        backend_name = "stub_backend"
+
+        backend = rpc.backend_registry.register_backend(
+            backend_name,
+            _stub_construct_rpc_backend_options_handler,
+            _stub_init_rpc_backend_handler,
+        )
+
+        # Registering the same name a second time must fail.
+        with self.assertRaisesRegex(
+            RuntimeError, "^RPC backend .+: already registered$"
+        ):
+            backend = rpc.backend_registry.register_backend(
+                backend_name,
+                _stub_construct_rpc_backend_options_handler,
+                _stub_init_rpc_backend_handler,
+            )
+
+        # ``backend`` still refers to the first, successful registration.
+        rpc.init_rpc(
+            name="worker1",
+            backend=backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+    @dist_init(setup_rpc=False)
+    def test_duplicate_name(self):
+        with self.assertRaisesRegex(RuntimeError, "is not unique"):
+            store, _, _ = next(
+                torch.distributed.rendezvous(
+                    self.init_method, rank=self.rank, world_size=self.world_size
+                )
+            )
+            rpc._init_rpc_backend(
+                backend=self.rpc_backend,
+                store=store,
+                name="duplicate_name",
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_duplicate_name_2(self):
+        with self.assertRaisesRegex(RuntimeError, "is not unique"):
+            rpc.init_rpc(
+                name=worker_name(self.rank % (self.world_size - 1)),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_reinit(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # TODO: with TCP init, rank 0 raises Address already in use because
+        # rank 0 is the start daemon and the store is created before checking if
+        # RPC is already initialized in init_rpc.
+        if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0:
+            expected_reinit_err = "Address already in use"
+        else:
+            expected_reinit_err = "is already initialized"
+
+        with self.assertRaisesRegex(RuntimeError, expected_reinit_err):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    def test_pg_init_no_rpc_init(self):
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.file_init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        class MyModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.lin = torch.nn.Linear(3, 4)
+
+            def forward(self, x):
+                return self.lin(x)
+
+        model = MyModel()
+        model.train()
+        model = torch.nn.parallel.DistributedDataParallel(model)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Current RPC agent is not set! Did you initialize the RPC framework",
+        ):
+            [RRef(param) for param in model.parameters()]
+
+    def test_world_size_one(self):
+        """Delegate to ``_world_size_one`` with two separate 2x2 tensors of ones."""
+        self._world_size_one(torch.ones(2, 2), torch.ones(2, 2))
+
+    @dist_init(setup_rpc=False)
+    def test_invalid_names(self):
+        worker_id = 0
+        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
+            WorkerInfo("abc*", worker_id)
+
+        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
+            WorkerInfo(" ", worker_id)
+
+        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
+            WorkerInfo("", worker_id)
+
+        # If the number in the message does not match, it is likely that the
+        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
+        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
+            WorkerInfo("".join(["a" for i in range(500)]), worker_id)
+
+    # Test that WorkerInfo can be pickled and sent in RPC call
+    @dist_init
+    def test_worker_info_pickle(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        worker_info = rpc.api.get_worker_info()
+        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
+        self.assertEqual(ret, worker_info)
+
+    @dist_init
+    def test_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+
+    @staticmethod
+    def return_callee_id():
+        return rpc.get_worker_info().id
+
+    @dist_init
+    def test_int_callee(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
+        self.assertEqual(ret, dst_rank)
+
+    @dist_init
+    def test_add_with_id(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        workder_info = rpc.get_worker_info(worker_name(dst_rank))
+
+        ret = rpc.rpc_sync(
+            workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_scalar_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n))
+        self.assertEqual(ret, (torch.ones(n, n) + n))
+
+    @dist_init
+    def test_async_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        fut = rpc.rpc_async(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_nonzero(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        x = torch.ones(self.world_size, self.world_size)
+        x[self.rank][self.rank] = 0
+        ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
+        self.assertEqual(ret, x.nonzero())
+
+    @dist_init
+    def test_multi_rpc(self):
+        """Delegate to ``_multi_rpc`` with ``False`` as its only argument."""
+        self._multi_rpc(False)
+
+    @dist_init
+    def test_future_wait_twice(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [rpc.rpc_async(dst, raise_func) for _ in range(20)]
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+        for fut in futs:
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut.wait()
+
+    @dist_init(setup_rpc=False)
+    def test_wait_all_workers_timeout(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        og_func = rpc.api._wait_all_workers
+
+        def wait_all_workers_sleep(timeout):
+            rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
+
+        rpc.api._wait_all_workers = wait_all_workers_sleep
+
+        try:
+            with self.assertRaisesRegex(RuntimeError, ""):
+                rpc.shutdown(graceful=True, timeout=0.01)
+        finally:
+            rpc.api._wait_all_workers = og_func
+        dist.barrier()
+
+    def test_wait_all_workers_dense(self):
+        """Delegate to ``_wait_all_workers`` with ``heavy_rpc`` and a 100x100 ones-tensor."""
+        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))
+
+    def test_wait_all_workers_twice_dense(self):
+        """Delegate to ``_wait_all_workers_twice`` with ``heavy_rpc`` and a 100x100 ones-tensor."""
+        self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100))
+
+    @dist_init
+    def test_all_gather(self):
+        info = rpc.get_worker_info()
+        results = rpc.api._all_gather(info.id)
+        expected = {}
+        for info in rpc._get_current_rpc_agent().get_worker_infos():
+            expected[info.name] = info.id
+
+        self.assertEqual(expected, results)
+
+    @dist_init
+    def test_all_gather_timeout(self):
+        rpc._set_rpc_timeout(0.1)
+
+        if self.rank == 0:
+            with self.assertRaisesRegex(
+                RuntimeError, "timed out in _all_gather after 0\\.10 seconds"
+            ):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+        else:
+            expected_error = self.get_timeout_error_regex()
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+
+    def _test_barrier_helper(self, info, names, multi_threaded=False):
+        names = sorted(names)
+        leader = names[0]
+        rpc.rpc_sync(leader, _reset_count)
+        if not multi_threaded and info.name == leader:
+            self.assertEqual(_rpc_barrier_count, 0)
+        rpc.api._barrier(names)
+        rpc.rpc_sync(leader, _increment_count)
+        rpc.api._barrier(names)
+        if not multi_threaded and info.name == leader:
+            self.assertEqual(_rpc_barrier_count, len(names))
+
+    @dist_init
+    def test_rpc_barrier_all(self):
+        # Test rpc barrier when called with full list of workers
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        names = [worker.name for worker in all_worker_info]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_subset(self):
+        # Test rpc barrier when processes are called with different subsets of the full list
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        if info.id % 2:
+            names = [worker.name for worker in all_worker_info if worker.id % 2]
+        else:
+            names = [worker.name for worker in all_worker_info if not worker.id % 2]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_partial_subset(self):
+        # Test rpc barrier when some processes are not involved in the barrier
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        if info.id % 2:
+            names = [worker.name for worker in all_worker_info if worker.id % 2]
+        else:
+            names = [f"worker{info.id}"]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_multithreaded(self):
+        # This tests validates the implementation of barrier when multiple threads call into it
+        # We only need to check that it does not hang in this case
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        names = [worker.name for worker in all_worker_info]
+        threads = []
+        for _ in range(3):
+            th = threading.Thread(
+                target=self._test_barrier_helper, args=(info, names, True)
+            )
+            threads.append(th)
+            th.start()
+        for th in threads:
+            th.join()
+
+    @dist_init
+    def test_graceful_shutdown_with_uneven_workload(self):
+        """Test graceful termination.
+
+        Delegates to ``_run_uneven_workload`` with ``heavy_rpc`` and a 100x100
+        tensor of ones.
+        """
+        self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))
+
+    @dist_init(setup_rpc=False)
+    def test_shutdown_followed_by_rpc(self):
+        # Initialize RPC.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+        rpc.shutdown()
+
+        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
+            rpc.rpc_sync(
+                worker_name(dst_rank),
+                torch.add,
+                args=(torch.ones(n, n), torch.ones(n, n)),
+            )
+
+    @dist_init
+    def test_expected_src(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        expected_src_rank = (self.rank - 1) % self.world_size
+        rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,))
+        value = VALUE_FUTURE.result()
+        self.assertEqual(value, expected_src_rank)
+
+    @dist_init
+    def test_py_built_in(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2))
+        self.assertEqual(ret, min(n, n + 1, n + 2))
+
+    @dist_init
+    def test_py_user_defined(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            my_function,
+            kwargs={"a": n, "b": n + 1, "c": n + 2},
+        )
+        self.assertEqual(ret, my_function(n, n + 1, n + 2))
+
+    def test_build_rpc_profiling_key(self):
+        # Tests that the name that shows up as an Event in profiling RPCs has all
+        # the necessary information.
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            rpc_profiling_key = _build_rpc_profiling_key(
+                exec_mode, "foo", "worker0", "worker1"
+            )
+            self.assertIn(exec_mode.value, rpc_profiling_key)
+            self.assertIn("foo", rpc_profiling_key)
+            self.assertIn("worker0", rpc_profiling_key)
+            self.assertIn("worker1", rpc_profiling_key)
+
+    def check_profiling_info(
+        self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode
+    ):
+        self.assertTrue(self_worker_name in rpc_event.name)
+        self.assertTrue(dst_worker_name in rpc_event.name)
+        if isinstance(func, torch.jit.ScriptFunction):
+            self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name)
+        else:
+            self.assertTrue(func.__name__ in rpc_event.name)
+        self.assertTrue(rpc_exec_mode.value in rpc_event.name)
+        self.assertEqual(rpc_event.count, 1)
+
+    @dist_init
+    def test_profiler_rpc_record_shapes(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        t1, t2 = torch.ones(100), torch.ones(100)
+        with _profile(record_shapes=True) as prof:
+            rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2))
+
+        function_events = prof.function_events
+        remote_events = [event for event in function_events if event.is_remote]
+        remote_add_event = next(
+            event for event in remote_events if "aten::add" in event.name
+        )
+        remote_add_input_shapes = remote_add_event.input_shapes
+        # Run profiler on equivalent local op and validate shapes are the same.
+        with _profile(record_shapes=True) as prof:
+            torch.add(t1, t2)
+
+        local_function_events = prof.function_events
+        local_add_event = next(
+            event for event in local_function_events if "aten::add" in event.name
+        )
+        local_add_input_shapes = local_add_event.input_shapes
+        self.assertEqual(remote_add_input_shapes, local_add_input_shapes)
+
+    @dist_init
+    def test_profiler_rpc_memory(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        with _profile(profile_memory=True) as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        function_events = p.function_events
+        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
+        # if cpu_memory_usage was not propagated over the wire, this set would
+        # only contain 0 (indicates no memory being profiled)
+        self.assertNotEqual({0}, event_cpu_mem_usages)
+        # No memory profiled if profile_memory=False
+        with _profile(profile_memory=False) as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        function_events = p.function_events
+        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
+        self.assertEqual({0}, event_cpu_mem_usages)
+
+    @dist_init
+    def test_profiler_export_trace(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        with _profile() as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        with TemporaryFileName() as fname:
+            path = fname
+            p.export_chrome_trace(path)
+            with open(path) as f:
+                trace = json.load(f)
+                event_names = [event["name"] for event in trace]
+                for expected_event_name in EXPECTED_REMOTE_EVENTS + [
+                    RPCExecMode.ASYNC.value
+                ]:
+                    event_exists = any(
+                        expected_event_name in event_name for event_name in event_names
+                    )
+                    self.assertTrue(event_exists)
+
+    @dist_init
+    def test_profiler_rpc_key_names(self):
+        # tests that remote events are properly prefixed with the RPC profiling key.
+        if self.rank != 1:
+            return
+
+        # Spawn multiple threads that send RPCs to ensure keys are correctly
+        # prefixed when there are multiple RPCs being created/in flight at the
+        # same time.
+        dst_ranks = [rank for rank in range(self.world_size) if rank != self.rank]
+
+        def rpc_with_profiling(dst_worker):
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+                fut.wait()
+
+            events = prof.function_events
+            remote_event_names = {
+                event.name: event for event in events if event.is_remote
+            }
+            rpc_profiling_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                udf_with_torch_ops.__qualname__,
+                worker_name(self.rank),
+                dst_worker,
+            )
+
+            remote_event_name_set = set(EXPECTED_REMOTE_EVENTS)
+            for name, event in remote_event_names.items():
+                # Ensure that we have the expected key as part of the remote
+                # event.
+                self.assertTrue(name.startswith(rpc_profiling_key))
+                self.assertTrue(event.is_remote)
+                self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id)
+                # Ensure that the remote event name also contains the operator.
+                operator_name_substr = name[len(rpc_profiling_key) :]
+                # Note: we don't assert that every remote event needs to be
+                # in the above set, the set is just a representative set of
+                # what we expect to see. The profiler can change and add more
+                # events, but we should always expect to see this representative
+                # set.
+                matching_event = {
+                    remote_event_name
+                    for remote_event_name in remote_event_name_set
+                    if remote_event_name in operator_name_substr
+                }
+                remote_event_name_set -= matching_event
+
+            # The set should be empty, otherwise its contained elements did
+            # not show up in the remote profiler output.
+            self.assertTrue(
+                remote_event_name_set == set(),
+                f"Expected {remote_event_name_set} to be included in remote profiler output.",
+            )
+
+        for dst in dst_ranks:
+            dst_worker = worker_name(dst)
+            num_parallel_rpcs = 2
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=num_parallel_rpcs
+            ) as executor:
+                futs = [
+                    executor.submit(rpc_with_profiling, dst_worker)
+                    for _ in range(num_parallel_rpcs)
+                ]
+                # Wait for workers to finish test
+                for fut in futs:
+                    fut.result()
+
+    def _run_test_profiler_remote_events_profiled(self):
+        # Tests that we can successfully invoke the profiler on a remote node,
+        # and collect the remote events back in the local profiler.
+        if self.rank != 1:
+            return
+
+        dst_ranks = [rank for rank in range(self.world_size) if rank != self.rank]
+        for dst in dst_ranks:
+            dst_worker = worker_name(dst)
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+                fut.wait()
+
+            events = prof.function_events
+
+            rpc_event = get_function_event(events, RPCExecMode.ASYNC.value)
+            self.check_profiling_info(
+                worker_name(self.rank),
+                dst_worker,
+                udf_with_torch_ops,
+                rpc_event,
+                RPCExecMode.ASYNC,
+            )
+
+            remote_events = {event.name: event for event in events if event.is_remote}
+            rpc_profiling_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                udf_with_torch_ops.__qualname__,
+                worker_name(self.rank),
+                worker_name(dst),
+            )
+
+            for expected_remote_event_name in EXPECTED_REMOTE_EVENTS:
+                expected_key = (
+                    rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name
+                )
+                self.assertTrue(expected_key in remote_events)
+                remote_event = remote_events[expected_key]
+                # Remote event should have a node ID corresponding to the worker
+                # it ran on.
+                self.assertEqual(remote_event.node_id, dst)
+
+            # Validate order remote events show up in profiling output.
+            def convert_remote_to_local(event_name):
+                remote_op_key = rpc_profiling_key + REMOTE_OP_STR
+                return event_name[event_name.find(remote_op_key) + len(remote_op_key) :]
+
+            remote_events_list = [
+                convert_remote_to_local(event.name)
+                for event in events
+                if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS
+            ]
+            self.assertEqual(
+                set(remote_events_list),
+                set(EXPECTED_REMOTE_EVENTS),
+                f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}",
+            )
+
+    @dist_init
+    def test_profiler_remote_events_profiled(self):
+        """Delegate to ``_run_test_profiler_remote_events_profiled``."""
+        self._run_test_profiler_remote_events_profiled()
+
+    @dist_init
+    def test_profiler_remote_events_profiled_single_threaded(self):
+        """Delegate to ``_run_test_profiler_remote_events_profiled``.
+
+        NOTE(review): the body is identical to
+        ``test_profiler_remote_events_profiled``; nothing visible here makes
+        this variant single-threaded -- confirm whether that is intended.
+        """
+        self._run_test_profiler_remote_events_profiled()
+
+    def run_profiling_workload(self, dst):
+        fut = rpc.rpc_async(
+            worker_name(dst),
+            torch.mul,
+            args=(
+                torch.tensor(1.0, requires_grad=True),
+                torch.tensor(1.0, requires_grad=True),
+            ),
+        )
+        fut.wait()
+
+    def _run_rpc_profiling_async_function(self, device="cpu"):
+        if self.rank != 1:
+            return
+
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+        x = torch.ones(2)
+        y = torch.ones(2)
+        with _profile() as prof:
+            ret = rpc.rpc_async(
+                dst1, slow_async_add, args=(dst2, x, y, device), timeout=20
+            )
+            ret.wait()
+
+        function_events = prof.function_events
+        # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be
+        # recorded.
+        key_prefix = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1
+        )
+
+        nested_rpc_key_prefix = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2
+        )
+        expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix
+        remote_events = [event for event in function_events if event.is_remote]
+        rpc_remote_event = [
+            event for event in remote_events if event.name == expected_key
+        ]
+        self.assertEqual(1, len(rpc_remote_event))
+        rpc_remote_event = rpc_remote_event[0]
+        self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size)
+        # slow_async_add's RPC does an add on dst2, which should be reflected as well.
+        remote_add_key = (
+            expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add)
+        )
+        remote_add_event = [
+            event for event in remote_events if event.name == remote_add_key
+        ]
+        self.assertEqual(1, len(remote_add_event))
+        remote_add_event = remote_add_event[0]
+        # Validate that node_id is dst2.
+        self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size)
+
+    @dist_init
+    def test_rpc_profiling_async_function(self):
+        """Run the async-function profiling check with the default device, then on CUDA.
+
+        The CUDA pass runs only when ``torch.cuda.is_available()`` is true and
+        is preceded by a process-group barrier.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        self._run_rpc_profiling_async_function()
+        if torch.cuda.is_available():
+            dist.barrier()
+            self._run_rpc_profiling_async_function(device="cuda:0")
+
+    @dist_init
+    def test_rpc_profiling_async_function_single_threaded(self):
+        """Run the async-function profiling check with the default device, then on CUDA.
+
+        NOTE(review): the body is identical to
+        ``test_rpc_profiling_async_function``; nothing visible here makes this
+        variant single-threaded -- confirm whether that is intended.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        self._run_rpc_profiling_async_function()
+        if torch.cuda.is_available():
+            dist.barrier()
+            self._run_rpc_profiling_async_function(device="cuda:0")
+
+    @dist_init
+    def test_rpc_profiling_remote_record_function(self):
+        # test that functions run over RPC with record_function show the expected
+        # profiled block.
+        if self.rank != 1:
+            return
+        dst_ranks = [i for i in range(self.world_size) if i != self.rank]
+        for dst_rank in dst_ranks:
+            dst_worker = worker_name(dst_rank)
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True))
+                fut.wait()
+
+            function_events = prof.function_events
+            record_function_remote_event = [
+                evt for evt in function_events if "##forward##" in evt.name
+            ]
+            self.assertEqual(1, len(record_function_remote_event))
+            record_function_remote_event = record_function_remote_event[0]
+            self.assertEqual(record_function_remote_event.node_id, dst_rank)
+            # cpu_children only returns direct children, so here we get all
+            # children recursively.
+
+            def get_cpu_children(event):
+                if not event.cpu_children:
+                    return []
+                cpu_children = event.cpu_children
+                for e in event.cpu_children:
+                    cpu_children.extend(get_cpu_children(e))
+                return cpu_children
+
+            remote_children = get_cpu_children(record_function_remote_event)
+            # Get local children and verify parity.
+            with _profile() as prof:
+                udf_with_torch_ops(-1, True)
+
+            local_function_events = prof.function_events
+            local_record_function_event = next(
+                evt for evt in local_function_events if "##forward##" in evt.name
+            )
+            local_children = get_cpu_children(local_record_function_event)
+            local_children_names = [evt.name for evt in local_children]
+
+            REMOTE_OP_STR = "#remote_op: "
+
+            def convert_remote_to_local(event_name):
+                remote_op_key = REMOTE_OP_STR
+                return event_name[event_name.find(remote_op_key) + len(remote_op_key) :]
+
+            for evt in remote_children:
+                local_name = convert_remote_to_local(evt.name)
+                self.assertTrue(local_name in local_children_names)
+
+    def validate_profiling_workload(self, dst, prof):
+        def convert_remote_to_local(event_name):
+            return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]
+
+        events = prof.function_events
+        remote_events = {
+            convert_remote_to_local(event.name): event
+            for event in events
+            if event.is_remote
+        }
+        self.assertTrue("aten::mul" in remote_events)
+        remote_mul_event = remote_events["aten::mul"]
+        self.assertEqual(remote_mul_event.node_id, dst)
+        self.check_profiling_info(
+            worker_name(self.rank),
+            worker_name(dst),
+            torch.mul,
+            remote_mul_event,
+            RPCExecMode.ASYNC,
+        )
+
+    def _run_test_profiler_with_autograd_context(self):
+        dst = (self.rank + 1) % self.world_size
+        if self.rank == 1:
+            # Cases where we can double wrap messages with profiling information and autograd info.
+            with dist_autograd.context(), _profile() as prof:
+                self.run_profiling_workload(dst)
+
+            self.validate_profiling_workload(dst, prof)
+
+            # Ensure that flipped order of ctx managers results in events being
+            # recorded as expected.
+            with _profile() as prof, dist_autograd.context():
+                self.run_profiling_workload(dst)
+
+            self.validate_profiling_workload(dst, prof)
+
+    @dist_init
+    def test_profiler_with_autograd_context_single_threaded(self):
+        """Delegate to ``_run_test_profiler_with_autograd_context``.
+
+        NOTE(review): the body is identical to
+        ``test_profiler_with_autograd_context``; nothing visible here makes
+        this variant single-threaded -- confirm whether that is intended.
+        """
+        self._run_test_profiler_with_autograd_context()
+
+    @dist_init
+    def test_profiler_with_autograd_context(self):
+        """Delegate to ``_run_test_profiler_with_autograd_context``."""
+        self._run_test_profiler_with_autograd_context()
+
+    def _profiler_test_with_rpc(
+        self,
+        rpc_exec_mode,
+        func,
+        args,
+        use_record_function=False,
+        dst=None,
+        kineto_profile=False,
+    ):
+        """Issue one RPC under a profiler on rank 1 and validate the recorded events.
+
+        Ranks other than 1 do nothing in this helper.
+
+        Args:
+            rpc_exec_mode: ``RPCExecMode`` member; ``SYNC`` uses ``rpc.rpc_sync``,
+                ``ASYNC`` uses ``rpc.rpc_async`` and any other value is asserted
+                to be ``REMOTE`` and uses ``rpc.remote``.
+            func: Callable sent to the destination worker.
+            args: Positional arguments forwarded with ``func``.
+            use_record_function: When true, the RPC is issued inside a
+                ``record_function("foo")`` scope and the scope's interval and
+                ordering relative to the RPC event are also checked.
+            dst: Destination rank; defaults to the next rank (wrapping around
+                ``world_size``).
+            kineto_profile: When true, ``torch.profiler.profile`` is used instead
+                of ``_profile`` and the test only checks that looking up an RPC
+                event raises ``IndexError``.
+        """
+        dst = dst if dst is not None else (self.rank + 1) % self.world_size
+
+        # only run profiler on rank 1.
+        p = _profile if not kineto_profile else torch.profiler.profile  # kineto
+        if self.rank == 1:
+            with p() as prof:
+                # A no-op context is used when no "foo" scope was requested so
+                # that the RPC dispatch below is written only once.
+                record_function_ctx_mgr = (
+                    contextlib.nullcontext()
+                    if not use_record_function
+                    else torch.autograd.profiler.record_function("foo")
+                )
+                with record_function_ctx_mgr:
+                    if rpc_exec_mode == RPCExecMode.SYNC:
+                        rpc.rpc_sync(worker_name(dst), func, args=args)
+                    elif rpc_exec_mode == RPCExecMode.ASYNC:
+                        fut = rpc.rpc_async(worker_name(dst), func, args=args)
+                        if kineto_profile:
+                            # Ensure multiple async RPCs don't cause issues.
+                            # Would have raised
+                            # "RuntimeError: Cannot call
+                            # RemoteProfilerManager::setCurrentKey when current
+                            # key is already set." error if RPC profiling was
+                            # not disabled properly for kineto.
+                            fut2 = rpc.rpc_async(worker_name(dst), func, args=args)
+                            fut2.wait()
+                        fut.wait()
+                    else:
+                        self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
+                        rref = rpc.remote(worker_name(dst), func, args=args)
+                        rref.to_here()
+                        # To avoid flakiness, wait for the RRef to be profiled. This
+                        # means that we received the acknowledgement of successful
+                        # creation on the owner and ran the callbacks responsible
+                        # for recording the profiling event.
+                        rref._get_profiling_future().wait()
+
+            events = prof.function_events if not kineto_profile else prof.events()
+            if kineto_profile:
+                # RPC profiling is disabled so there should be no rpc related
+                # events.
+                with self.assertRaises(IndexError):
+                    get_function_event(events, rpc_exec_mode.value)
+
+                return
+
+            rpc_event = get_function_event(events, rpc_exec_mode.value)
+            # verify Node ID for this rpc event.
+            self.assertEqual(rpc_event.node_id, self.rank)
+            # Ensure recording of remote events.
+            # The RPC event itself is removed from the set because it also
+            # carries ``node_id == dst`` when the destination is this rank.
+            remote_events = {event for event in events if event.node_id == dst} - {
+                rpc_event
+            }
+            self.assertGreaterEqual(len(remote_events), 1)
+            for remote_event in remote_events:
+                self.assertEqual(remote_event.node_id, dst)
+
+            if use_record_function:
+                scope_event = get_function_event(events, "foo")
+                # Since RPC call is within the scope, its CPU interval should be
+                # contained within foo's interval.
+                self.assertLessEqual(
+                    scope_event.time_range.start, rpc_event.time_range.start
+                )
+                self.assertGreaterEqual(
+                    scope_event.time_range.end, rpc_event.time_range.end
+                )
+            # the sender, dest worker, function run, and type of RPC should all
+            # be recorded.
+            self_worker_name = worker_name(self.rank)
+            dst_worker_name = worker_name(dst)
+            self.check_profiling_info(
+                self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode
+            )
+            if use_record_function:
+                # verify order by ensuring that the outer context comes
+                # before the rpc event.
+                foo_event_ix = next(
+                    i for i, event in enumerate(events) if "foo" in event.name
+                )
+                rpc_event_idx = next(
+                    i
+                    for i, event in enumerate(events)
+                    if rpc_exec_mode.value in event.name
+                )
+                self.assertLess(foo_event_ix, rpc_event_idx)
+
+    def _run_test_profiler_with_sync_rpc_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, my_sleep_func, args=(1,), use_record_function=True
+        )
+
+    @dist_init
+    def test_profiler_with_sync_rpc_udf(self):
+        """Run the shared sync-RPC user-defined-function profiling scenario."""
+        self._run_test_profiler_with_sync_rpc_udf()
+
+    @dist_init
+    def test_profiler_with_sync_rpc_udf_single_threaded(self):
+        """Run the shared sync-RPC user-defined-function profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_sync_rpc_udf()
+
+    def _run_test_profiler_with_sync_rpc_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_sync_rpc_builtin(self):
+        """Run the shared sync-RPC builtin-operator profiling scenario."""
+        self._run_test_profiler_with_sync_rpc_builtin()
+
+    @dist_init
+    def test_profiler_with_sync_rpc_builtin_single_threaded(self):
+        """Run the shared sync-RPC builtin-operator profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_sync_rpc_builtin()
+
+    def _run_test_profiler_with_async_rpc_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_sleep_func, args=(1,), use_record_function=True
+        )
+        # Test to ensure that kineto profiler enabled in RPC does not enable
+        # RPC profiling (it is unsupported) and does not result in issues.
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True
+        )
+
+    @dist_init
+    def test_profiler_with_async_rpc_udf(self):
+        """Run the shared async-RPC user-defined-function profiling scenario."""
+        self._run_test_profiler_with_async_rpc_udf()
+
+    @dist_init
+    def test_profiler_with_async_rpc_udf_single_threaded(self):
+        """Run the shared async-RPC user-defined-function profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_async_rpc_udf()
+
+    def _run_test_profiler_with_async_rpc_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_async_rpc_builtin(self):
+        """Run the shared async-RPC builtin-operator profiling scenario."""
+        self._run_test_profiler_with_async_rpc_builtin()
+
+    @dist_init
+    def test_profiler_with_async_rpc_builtin_single_threaded(self):
+        """Run the shared async-RPC builtin-operator profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_async_rpc_builtin()
+
+    def _run_test_profiler_with_remote_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank
+        )
+
+    @dist_init
+    def test_profiler_with_remote_udf(self):
+        """Run the shared ``rpc.remote`` user-defined-function profiling scenario."""
+        self._run_test_profiler_with_remote_udf()
+
+    @dist_init
+    def test_profiler_with_remote_udf_single_threaded(self):
+        """Run the shared ``rpc.remote`` user-defined-function profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_remote_udf()
+
+    def _run_test_profiler_with_remote_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            dst=self.rank,
+        )
+
+    @dist_init
+    def test_profiler_with_remote_builtin(self):
+        """Run the shared ``rpc.remote`` builtin-operator profiling scenario."""
+        self._run_test_profiler_with_remote_builtin()
+
+    @dist_init
+    def test_profiler_with_remote_builtin_single_threaded(self):
+        """Run the shared ``rpc.remote`` builtin-operator profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_remote_builtin()
+
+    def _run_test_profiler_with_script_async_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_script_async_rpc(self):
+        """Run the shared async-RPC script-function profiling scenario."""
+        self._run_test_profiler_with_script_async_rpc()
+
+    @dist_init
+    def test_profiler_with_script_async_rpc_single_threaded(self):
+        """Run the shared async-RPC script-function profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_script_async_rpc()
+
+    def _run_test_profiler_with_script_sync_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_script_sync_rpc(self):
+        """Run the shared sync-RPC script-function profiling scenario."""
+        self._run_test_profiler_with_script_sync_rpc()
+
+    @dist_init
+    def test_profiler_with_script_sync_rpc_single_threaded(self):
+        """Run the shared sync-RPC script-function profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_script_sync_rpc()
+
+    def _run_test_profiler_with_script_remote_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank
+        )
+
+    @dist_init
+    def test_profiler_with_script_remote_rpc(self):
+        """Run the shared ``rpc.remote`` script-function profiling scenario."""
+        self._run_test_profiler_with_script_remote_rpc()
+
+    @dist_init
+    def test_profiler_with_script_remote_rpc_single_threaded(self):
+        """Run the shared ``rpc.remote`` script-function profiling scenario.
+
+        NOTE(review): same body as the non-``_single_threaded`` test; the
+        threading difference is presumably configured elsewhere -- confirm.
+        """
+        self._run_test_profiler_with_script_remote_rpc()
+
+    def _assert_top_level_events(
+        self, process_global_events, expected_top_level_event_names
+    ):
+        top_level_event_names = []
+        for thread_local_events in process_global_events:
+            # Get top-level events from all events happened on a thread.
+            last_end_time = 0
+            for event in thread_local_events:
+                event_name = event.name
+                time_range = event.time_range
+                if time_range.start > last_end_time:
+                    top_level_event_names.append(event_name)
+                    last_end_time = time_range.end
+        top_level_event_names = sorted(top_level_event_names)
+        expected_top_level_event_names = sorted(expected_top_level_event_names)
+        self.assertEqual(
+            top_level_event_names,
+            expected_top_level_event_names,
+            f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}",
+        )
+
+    @dist_init
+    def test_server_process_global_profiler(self):
+        """Nest two server-side process-global profiles and check their events.
+
+        Only rank 0 drives the test. An outer profile is entered on the next
+        rank, ``torch.add`` is run there, an inner profile is entered,
+        ``torch.sub`` is run, and both profiles are exited innermost first.
+        The inner profile is expected to hold only ``aten::sub`` as a top-level
+        event; the outer one is expected to hold both.
+        """
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = worker_name(dst_rank)
+
+        x = torch.tensor(1)
+        y = torch.tensor(2)
+
+        # The profile objects live on the destination worker; they are driven
+        # through their RRefs by invoking __enter__/__exit__ remotely.
+        outer_profile_rref = rpc.remote(
+            dst_worker_name, rpc._server_process_global_profile
+        )
+        outer_profile_rref.rpc_sync().__enter__()
+        rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
+        inner_profile_rref = rpc.remote(
+            dst_worker_name, rpc._server_process_global_profile
+        )
+        inner_profile_rref.rpc_sync().__enter__()
+        rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
+        inner_profile_rref.rpc_sync().__exit__(None, None, None)
+        outer_profile_rref.rpc_sync().__exit__(None, None, None)
+
+        inner_events = rpc.rpc_sync(
+            dst_worker_name, get_events_from_profile, (inner_profile_rref,)
+        )
+        expected_inner_events = ["aten::sub"]
+        expected_outer_events = expected_inner_events + ["aten::add"]
+
+        self._assert_top_level_events(inner_events, expected_inner_events)
+        outer_events = rpc.rpc_sync(
+            dst_worker_name, get_events_from_profile, (outer_profile_rref,)
+        )
+        self._assert_top_level_events(outer_events, expected_outer_events)
+
+        # The results are discarded; these calls only check that
+        # ``key_averages`` can be invoked on both profiles without raising.
+        inner_profile_rref.rpc_sync().key_averages()
+        outer_profile_rref.rpc_sync().key_averages()
+
+    @dist_init
+    def test_async_record_function_double_end_callbacks(self):
+        """Check that attaching end callbacks to a future twice raises.
+
+        Only rank 1 runs the body; the RPC is sent to worker 0.
+        """
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            # Validate that calling the function twice results in an error.
+            with _profile():
+                with torch.autograd.profiler.record_function("foo") as rf:
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    rf._call_end_callbacks_on_future(fut)
+                    with self.assertRaisesRegex(
+                        RuntimeError, "can only be called once."
+                    ):
+                        rf._call_end_callbacks_on_future(fut)
+                # Wait for the RPC only after the record_function scope closed.
+                fut.wait()
+
+    @dist_init
+    def test_async_record_function_legacy(self):
+        # Test the legacy _record_function ops work
+        # Note: These exist for backward compatibility with TorchScript
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            with _profile():
+                try:
+                    handle = torch.ops.profiler._record_function_enter("foo", None)
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
+                finally:
+                    torch.ops.profiler._record_function_exit(handle)
+
+                fut.wait()
+
+    @dist_init
+    def test_async_record_function_cbs_jit_call(self):
+        if self.rank == 1:
+            with _profile() as pf:
+                key = _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC,
+                    torch._jit_internal._qualified_name(my_script_func),
+                    "worker1",
+                    "worker0",
+                )
+                with torch.autograd.profiler.record_function(key) as rf:
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_script_func, args=(torch.tensor(1),)
+                    )
+                    # Intentionally calling record_function internals
+                    fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(
+                        rf.record, fut
+                    )
+                result = fut.wait()
+                # Validate that the profiling future returns the same value as the RPC
+                # future.
+                expected = torch.add(torch.tensor(1), torch.tensor(1))
+                self.assertEqual(result, expected)
+            events = pf.function_events
+            rpc_event = get_function_event(
+                events, torch._jit_internal._qualified_name(my_script_func)
+            )
+            self.assertTrue(
+                torch._jit_internal._qualified_name(my_script_func) in rpc_event.name
+            )
+
+    @dist_init
+    def test_py_class_constructor(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,))
+        self.assertEqual(ret.a, n)
+
+    @dist_init
+    def test_py_class_instance_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,)
+        )
+        self.assertEqual(ret, MyClass(2).my_instance_method(n))
+
+    @dist_init
+    def test_py_class_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1)
+        )
+        self.assertEqual(ret, MyClass.my_class_method(n, n + 1))
+
+    @dist_init
+    def test_py_class_static_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,)
+        )
+        self.assertEqual(ret, MyClass.my_static_method(n + 10))
+
+    @dist_init
+    def test_py_multi_async_call(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        dst_worker_info = rpc.get_worker_info(worker_name(dst_rank))
+        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
+        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
+        self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
+        self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
+
+    @dist_init
+    def test_py_no_return_result(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), no_result)
+        self.assertEqual(ret, no_result())
+
+    @dist_init
+    def test_py_tensors(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            my_tensor_function,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
+
+    @dist_init
+    def test_py_tensors_multi_async_call(self):
+        futs = []
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        for i in range(100):
+            fut = rpc.rpc_async(
+                worker_name(dst_rank),
+                my_tensor_function,
+                args=(torch.ones(i, i), torch.ones(i, i)),
+            )
+            futs.append(fut)
+
+        for j, val in enumerate(torch.futures.wait_all(futs)):
+            self.assertEqual(
+                val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
+            )
+
+    @dist_init
+    def test_py_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [torch.ones(n, n), torch.ones(n, n)]
+        b = TensorClass(build_complex_tensors())
+        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c)
+        )
+        self.assertEqual(ret, my_complex_tensor_function(a, b, c))
+
+    @dist_init
+    def test_py_nested_pickle(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            run_nested_pickle,
+            args=(MyPickleClass(), torch.ones(2, 2)),
+        )
+
+        m = MyPickleClass()
+        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
+        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
+
+    @dist_init
+    def test_py_function_exception(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        with self.assertRaises(TypeError):
+            rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,))
+
+    @dist_init
+    def test_py_raise_in_user_func(self):
+        """Check that a remote ``raise_func`` failure reaches the caller and stderr.
+
+        The async RPC must raise ``ValueError`` matching ``expected_err`` on
+        ``wait()``, and the same text must show up in the captured stderr.
+        """
+        with captured_output() as (_, err):
+            # This barrier prevents a race condition where the main thread has
+            # not entered the context manager when the remote function runs.
+            initialize_pg(self.file_init_method, self.rank, self.world_size)
+            dist.barrier()
+            n = self.rank + 1
+            dst_rank = n % self.world_size
+            fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
+            with self.assertRaisesRegex(ValueError, expected_err):
+                fut.wait()
+            # This barrier prevents a race condition where the main thread exits
+            # context manager before the remote function has run.
+            dist.barrier()
+
+        # Validate that trainers log errors when running functions.
+        stderr_lines = err.getvalue()
+        self.assertTrue(expected_err in stderr_lines)
+
+    @dist_init
+    def test_py_raise_in_user_func_escaped_str(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
+        try:
+            fut.wait()
+        except ValueError as e:
+            msg = str(e)
+            # Ensure newlines are unescaped to provide a better repr of error.
+            self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape"))
+        else:
+            self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
+
+    @dist_init
+    def test_nested_rpc(self):
+        """Run the shared ``_nested_rpc`` scenario with ``nested_rpc``.
+
+        NOTE(review): the second argument (``torch.ones(2, 2) + 1``) is
+        presumably the expected result -- confirm against the helper.
+        """
+        self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_stress_light_rpc(self):
+        """Run the shared ``_stress_test_rpc`` scenario with ``light_rpc`` and
+        the helper's default settings."""
+        self._stress_test_rpc(light_rpc)
+
+    @dist_init
+    def test_stress_heavy_rpc(self):
+        """Run the shared ``_stress_test_rpc`` scenario with ``heavy_rpc``,
+        ``repeat=20`` and a 100x100 ones tensor as the argument."""
+        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
+
+    @dist_init
+    def test_stress_heavy_rpc_torchscript(self):
+        """Run the shared ``_stress_test_rpc`` scenario with
+        ``heavy_rpc_torchscript``, ``repeat=20`` and a 100x100 ones tensor."""
+        self._stress_test_rpc(
+            heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),)
+        )
+
+    @dist_init
+    def test_builtin_remote_ret(self):
+        """Run the shared ``_builtin_remote_ret`` scenario with dense 2x2 tensors.
+
+        NOTE(review): the third argument (``torch.ones(2, 2) * 2``) is
+        presumably the expected result -- confirm against the helper.
+        """
+        self._builtin_remote_ret(
+            torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self(self):
+        """Run the shared ``_builtin_remote_self`` scenario with dense 2x2 tensors.
+
+        NOTE(review): the third argument (``torch.ones(2, 2) * 2``) is
+        presumably the expected result -- confirm against the helper.
+        """
+        self._builtin_remote_self(
+            torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2
+        )
+
+    @staticmethod
+    def _multi_args_fn(n, sparse=False):
+        if sparse:
+            return (build_sparse_tensor(), build_sparse_tensor())
+        else:
+            return (torch.ones(n, n), torch.ones(n, n))
+
+    @dist_init
+    def test_multi_builtin_remote_ret(self):
+        """Run the shared ``_test_multi_remote_call`` scenario for ``torch.add``,
+        passing ``False`` and using ``_multi_args_fn`` to build the arguments."""
+        self._test_multi_remote_call(torch.add, False, args_fn=RpcTest._multi_args_fn)
+
+    @dist_init
+    def test_py_udf_remote(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            my_function,
+            kwargs={"a": n, "b": n + 1, "c": n + 2},
+        )
+        self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
+
+    @staticmethod
+    def _multi_kwargs_fn(n, sparse=False):
+        if sparse:
+            return {
+                "a": build_sparse_tensor(),
+                "b": build_sparse_tensor(),
+                "c": build_sparse_tensor(),
+            }
+        else:
+            return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}
+
+    @dist_init
+    def test_multi_py_udf_remote(self):
+        """Run the shared ``_test_multi_remote_call`` scenario for ``my_function``,
+        passing ``False`` and using ``_multi_kwargs_fn`` to build the kwargs."""
+        self._test_multi_remote_call(
+            my_function, False, kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    @dist_init
+    def test_py_rref_args(self):
+        """Run the shared ``_py_rref_args`` scenario with dense 2x2 tensors.
+
+        NOTE(review): the argument meanings are defined by the helper; the
+        last one is presumably the expected result -- confirm.
+        """
+        self._py_rref_args(
+            torch.ones(2, 2), 1, torch.ones(2, 2), 2, torch.ones(2, 2) * 2 + 3
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share(self):
+        """Run the shared ``_py_rref_args_user_share`` scenario with dense tensors.
+
+        NOTE(review): the argument meanings are defined by the helper; the
+        last one is presumably the expected result -- confirm.
+        """
+        self._py_rref_args_user_share(
+            torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_py_rpc_rref_args(self):
+        """Run the shared ``_py_rpc_rref_args`` scenario with dense tensors.
+
+        NOTE(review): the argument meanings are defined by the helper; the
+        last one is presumably the expected result -- confirm.
+        """
+        self._py_rpc_rref_args(
+            torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_nested_remote(self):
+        """Run the shared ``_nested_remote`` scenario with ``nested_remote``.
+
+        NOTE(review): the second argument (``torch.ones(2, 2) + 3``) is
+        presumably the expected result -- confirm against the helper.
+        """
+        self._nested_remote(nested_remote, torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_nested_rref(self):
+        """Run the shared ``_nested_rref`` scenario with ``nested_rref``.
+
+        NOTE(review): the two tensor arguments are presumably the expected
+        values of the nested RRefs -- confirm against the helper.
+        """
+        self._nested_rref(nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+
+    @dist_init
+    def test_nested_rref_stress(self):
+        """Run the shared ``_nested_rref_stress`` scenario with ``nested_rref``.
+
+        NOTE(review): the two tensor arguments are presumably the expected
+        values of the nested RRefs -- confirm against the helper.
+        """
+        self._nested_rref_stress(
+            nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_multi_layer_nested_async_rpc(self):
+        # This test will exit right away, but there will be a chain of async
+        # RPCs. The termination algorithm should detect those messages properly.
+        # Otherwise, some peer could exit early, leaving others to timeout
+        # errors or connection closed errors.
+        ttl = 20
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
+
+    @dist_init
+    def test_remote_with_exception(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        # check ref to other workers
+        rref = rpc.remote(worker_name(dst_rank), raise_func)
+        with self.assertRaises(ValueError):
+            rref.to_here()
+        # check ref to itself
+        rref = rpc.remote(worker_name(self.rank), no_result, args=(10,))
+        with self.assertRaises(TypeError):
+            rref.to_here()
+
+    @dist_init
+    def test_rpc_return_rref(self):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        rref = rpc.rpc_sync(
+            worker_name(dst_rank1),
+            rpc_return_rref,
+            args=(worker_name(dst_rank2),),
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_rref_forward_chain(self):
+        ttl = 8
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        rref = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1))
+
+        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)
+
+        for _ in range(ttl):
+            self.assertEqual(len(ret_rref), 1)
+            ret_rref = ret_rref[0].to_here()
+
+        ret = ret_rref
+        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
+
+    @dist_init
+    def test_local_rref_no_fork(self):
+        """An RRef built directly from a value exposes it via ``local_value()``."""
+        local_rref = RRef(35)
+        self.assertEqual(local_rref.local_value(), 35)
+
+    @dist_init
+    def test_local_value_not_on_owner(self):
+        """``local_value()`` on an RRef owned by another worker must raise.
+
+        The expected ``RuntimeError`` message names the calling worker (this
+        rank) and the owning worker (the next rank).
+        """
+        # ensure that an error message is thrown if a user tries to call
+        # local_value() on a non-owning node.
+        next_rank = (self.rank + 1) % self.world_size
+        rref = rpc.remote(
+            worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        # NOTE(review): the pattern hard-codes local_id=0 / local_id=1, which
+        # presumably relies on this being the first RRef created on this
+        # worker in the test -- confirm.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            (
+                rf"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), "
+                rf"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), "
+                r"can't call localValue\(\) on user "
+                rf"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). "
+                rf"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)"
+            ),
+        ):
+            rref.local_value()
+
+    @dist_init
+    def test_return_local_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        rref_list = rpc.rpc_sync(
+            worker_name(dst_rank), get_rref_list, args=([1, 2, 3],)
+        )
+
+        for rref in rref_list:
+            rpc.rpc_sync(
+                rref.owner(),
+                _call_method_on_rref,
+                args=(MyClass.increment_value, rref, 10),
+            )
+
+        rets = [
+            rpc.rpc_sync(
+                rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
+            )
+            for rref in rref_list
+        ]
+
+        self.assertEqual(rets, [11, 12, 13])
+
+    @dist_init
+    def _test_rref_type(self, blocking):
+        """Check ``RRef._get_type`` for a user RRef, including result caching.
+
+        Args:
+            blocking: Forwarded to ``_get_type``. When false the call returns
+                a future that is waited on to obtain the type.
+
+        The first call is expected to launch an ``_rref_typeof_on_owner`` RPC
+        (detected through profiler events); the ten calls after it must not.
+        """
+
+        def launched_rpc(events):
+            # True when any profiled event name starts with the async-RPC
+            # prefix for ``_rref_typeof_on_owner``.
+            expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
+            return any(e.name.startswith(expected_name) for e in events)
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))
+
+        with _profile() as p:
+            t = rref._get_type(blocking=blocking)
+            if not blocking:
+                t = t.wait()
+
+        self.assertTrue(launched_rpc(p.function_events))
+        expected_type = type(torch.ones(2))
+        self.assertEqual(t, expected_type)
+
+        futs = []
+
+        def verify(fut):
+            self.assertEqual(fut.value(), expected_type)
+
+        with _profile() as p:
+            for _ in range(10):
+                t = rref._get_type(blocking=blocking)
+                if not blocking:
+                    futs.append(t)
+                    t.add_done_callback(verify)
+                    t = t.wait()
+                self.assertEqual(t, expected_type)
+
+        if not blocking:
+            # Note that cached calls with blocking=False all return the same
+            # cached original future.
+            first_fut = futs[0]
+            for f in futs[1:]:
+                self.assertTrue(f is first_fut)
+        # Ensure we never launch another RPC, other than for the very
+        # first call.
+        self.assertFalse(launched_rpc(p.function_events))
+        self.assertEqual(t, type(torch.ones(2)))
+
+        # Repeat the type lookup for an RRef holding a user-defined class.
+        rref = rpc.remote(dst, MyClass, args=(0,))
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, MyClass)
+
+    def test_rref_type_blocking(self):
+        """Run ``_test_rref_type`` with ``blocking=True``; the helper itself
+        carries the ``@dist_init`` decorator."""
+        self._test_rref_type(blocking=True)
+
+    def test_rref_type_non_blocking(self):
+        """Run ``_test_rref_type`` with ``blocking=False``; the helper itself
+        carries the ``@dist_init`` decorator."""
+        self._test_rref_type(blocking=False)
+
+    @dist_init
+    def _test_rref_type_with_error(self, blocking):
+        """Check that ``RRef._get_type`` surfaces the error from a failed RRef.
+
+        Args:
+            blocking: Forwarded to ``_get_type``. When true the ``ValueError``
+                is expected from the call itself; otherwise from ``wait()`` on
+                the returned future.
+        """
+        dst = worker_name((self.rank + 1) % self.world_size)
+        # NOTE(review): an earlier comment here mentioned a 10 ms timeout, but
+        # no timeout is passed to ``rpc.remote``; the failure comes from
+        # ``raise_func`` running on the owner.
+        rref = rpc.remote(dst, raise_func)
+        # Blocking: error raised inline
+        if blocking:
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                rref._get_type(blocking=blocking)
+        else:
+            # Non-blocking: Immediately return future, block on wait
+            fut = rref._get_type(blocking=blocking)
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut.wait()
+
+    def test_rref_type_with_error_blocking(self):
+        """Run ``_test_rref_type_with_error`` with ``blocking=True``; the helper
+        itself carries the ``@dist_init`` decorator."""
+        self._test_rref_type_with_error(blocking=True)
+
+    def test_rref_type_with_error_non_blocking(self):
+        """Run ``_test_rref_type_with_error`` with ``blocking=False``; the helper
+        itself carries the ``@dist_init`` decorator."""
+        self._test_rref_type_with_error(blocking=False)
+
+    @dist_init
+    def _test_rref_type_owner(self, blocking):
+        rref = RRef(torch.ones(2) + 1)
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, type(torch.ones(2)))
+
+        rref = RRef(MyClass(0))
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, MyClass)
+
+    def test_rref_type_owner_blocking(self):
+        """Run ``_test_rref_type_owner`` with ``blocking=True``; the helper
+        itself carries the ``@dist_init`` decorator."""
+        self._test_rref_type_owner(blocking=True)
+
+    def test_rref_type_owner_non_blocking(self):
+        """Run ``_test_rref_type_owner`` with ``blocking=False``; the helper
+        itself carries the ``@dist_init`` decorator."""
+        self._test_rref_type_owner(blocking=False)
+
+    @staticmethod
+    def _slow_add(x, y):
+        """Return ``x + y`` after sleeping for one second."""
+        # The delay makes the value slow to become available on the owner.
+        time.sleep(1)
+        return x + y
+
+    @dist_init
+    def test_rref_type_slow_init(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1))
+        self.assertEqual(rref._get_type(), type(torch.ones(2)))
+
+    @dist_init
+    def test_owner_equality(self):
+        """Check equality and hashing of the objects returned by ``RRef.owner()``.
+
+        Two locally created RRefs and two RRefs owned by the next rank are
+        compared pairwise, and their owners are then used as dict keys.
+        """
+        a = RRef(40)
+        b = RRef(50)
+
+        other_rank = (self.rank + 1) % self.world_size
+        other_a = rpc.remote(
+            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
+        )
+        other_b = rpc.remote(
+            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
+        )
+        other_a.to_here()  # to ensure clean termination
+        other_b.to_here()
+
+        # Comparing an owner with an unrelated type must not report equality.
+        self.assertNotEqual(a.owner(), 23)
+        self.assertEqual(other_a.owner(), other_b.owner())
+        self.assertNotEqual(a.owner(), other_a.owner())
+        self.assertEqual(other_a.owner(), other_a.owner())
+        # NOTE(review): repeats the assertion made three lines above.
+        self.assertEqual(other_a.owner(), other_b.owner())
+        self.assertEqual(a.owner(), a.owner())
+        self.assertEqual(a.owner(), b.owner())
+        self.assertEqual(a.owner(), rpc.get_worker_info())
+        # Owners of RRefs living on the same worker must land on the same
+        # dict key, so only two entries are expected.
+        x = {}
+        x[a.owner()] = a
+        x[other_a.owner()] = other_a
+        self.assertEqual(x[a.owner()], a)
+        self.assertEqual(x[b.owner()], a)
+        self.assertEqual(x[other_a.owner()], other_a)
+        self.assertEqual(x[other_b.owner()], other_a)
+        self.assertEqual(len(x), 2)
+
+    @dist_init
+    def test_pass_local_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        dst_worker = worker_name(dst_rank)
+
+        rref = RRef(40)
+        self.assertEqual(
+            rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90
+        )
+        self.assertEqual(
+            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90
+        )
+        self.assertEqual(
+            rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90
+        )
+
+    @dist_init
+    def test_remote_same_worker(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(
+            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
+        )
+        rref_b = rpc.remote(
+            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
+        )
+        rref_c = rpc.remote(
+            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+
+    @dist_init(setup_rpc=True)
+    def test_call_method_on_rref(self):
+        """
+        Tests that it is possible to call an instance method on a remote object
+        by using rref.owner() as destination of the call.
+        """
+        vals = [10, 2, 5, 7]
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst_rank)
+
+        # creates a remote object
+        rref = rpc.remote(dst_worker, MyClass, args=(vals[0],))
+
+        # modifies state of the remote object
+        # (once through each of rpc_sync, rpc_async and rpc.remote, each time
+        # addressed to rref.owner() rather than to the worker name)
+        rpc.rpc_sync(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[1]),
+        )
+        rpc.rpc_async(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[2]),
+        ).wait()
+        rpc.remote(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[3]),
+        ).to_here()
+
+        # queries state of the remote object
+        result = rpc.rpc_sync(
+            dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
+        )
+
+        # Expected: vals[0] from construction plus the three increments above
+        # (assumes increment_value adds its argument -- confirm in MyClass).
+        self.assertEqual(result, sum(vals))
+
+    # Notice `rpc.api.shutdown()` accesses
+    # `_delete_all_user_and_unforked_owner_rrefs` through
+    # `torch.distributed.rpc.api`, so patching
+    # `torch.distributed.rpc._delete_all_user_and_unforked_owner_rrefs` will
+    # not help.
+    @mock.patch.object(
+        torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs"
+    )
+    def _test_rref_leak(
+        self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak
+    ):
+        # Shared body for the RRef leak tests. With the RRef cleanup function
+        # replaced by a mock, a user RRef is created and then a graceful
+        # shutdown is requested. `ignore_leak` selects whether the shutdown is
+        # expected to succeed or to raise "Leaking RRef".
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # The RRef is deliberately kept alive in a local and otherwise unused.
+        rref = rpc.remote(  # noqa: F841
+            worker_name((self.rank + 1) % self.world_size),
+            torch.add,
+            args=(torch.ones(2, 2), 1),
+        )
+
+        import torch.distributed.rpc.api as api
+
+        if ignore_leak:
+            api._ignore_rref_leak = True
+            rpc.shutdown(graceful=True)
+        else:
+            api._ignore_rref_leak = False
+            with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
+                rpc.shutdown(graceful=True)
+
+    @dist_init(setup_rpc=False)
+    def test_rref_leak(self):
+        # Leak detection on: the helper expects shutdown to raise
+        # "Leaking RRef". RPC is initialized inside the helper.
+        self._test_rref_leak(ignore_leak=False)
+
+    @dist_init(setup_rpc=False)
+    def test_ignore_rref_leak(self):
+        # Leak detection suppressed via `_ignore_rref_leak`: the helper expects
+        # the graceful shutdown to complete without raising.
+        self._test_rref_leak(ignore_leak=True)
+
+    @dist_init
+    def test_rref_str(self):
+        rref1 = RRef(self.rank)
+        id_class = "GloballyUniqueId"
+        self.assertEqual(
+            f"OwnerRRef({id_class}(created_on={self.rank}, local_id=0))",
+            rref1.__str__(),
+        )
+
+        dst_rank = (self.rank + 1) % self.world_size
+        rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        self.assertEqual(
+            rref2.__str__(),
+            f"UserRRef(RRefId = {id_class}(created_on={self.rank}, local_id=1), "
+            f"ForkId = {id_class}(created_on={self.rank}, local_id=2))",
+        )
+
+    @dist_init
+    def test_rref_get_future(self):
+        # Tests that we can obtain the future corresponding to the creation of
+        # the RRef on remote end
+        if self.rank == 0:
+            # Builtin
+            rref = rpc.remote(worker_name(1), torch.add, args=(1, 1))
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+            # UDF
+            rref = rpc.remote(worker_name(1), foo_add, args=())
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+            # Script
+            rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1),))
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+    @dist_init
+    def test_rref_context_debug_info(self):
+        # This test checks local states that are modified by remote workers.
+        # This means that we would need barrier before and after every check.
+        # The barrier before the check makes sure that all previous states are
+        # cleared globally, the barrier after ensures that no following states
+        # change gets into the current check.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # Check 1: local RRef does not update owners_ map or add a pending user.
+        #################################################
+
+        rref1 = RRef(self.rank)
+
+        # don't need a barrier here as local RRef is handled by this thread
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertIn("num_pending_users", info)
+        # RRef on local value is not added to context until shared across RPC
+        # (the debug-info values are converted with int() before comparing)
+        self.assertEqual(0, int(info["num_owner_rrefs"]))
+        self.assertEqual(0, int(info["num_pending_users"]))
+        # barrier after the check 1
+        dist.barrier()
+
+        # Check 2: Sharing RRef as an arg should update owners_ map
+        ###########################################################
+
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,))
+
+        # barrier before check 2
+        wait_until_pending_futures_and_users_flushed()
+        dist.barrier()
+
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertEqual(1, int(info["num_owner_rrefs"]))
+        # no pending users since the fork is finished
+        self.assertEqual(0, int(info["num_pending_users"]))
+        # barrier after check 2
+        dist.barrier()
+
+        # clear states for check 2
+        rpc.rpc_sync(worker_name(dst_rank), clear_global_rref)
+
+        # Wait for owner rref to be cleared.
+        # Polls the local debug info every 0.1s; there is no upper bound on
+        # the number of iterations.
+        while int(info["num_owner_rrefs"]) != 0:
+            info = _rref_context_get_debug_info()
+            time.sleep(0.1)
+        dist.barrier()
+
+        # Check 3: rpc.remote call should update owners_ map
+        ####################################################
+        rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        rref3 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        rref2.to_here()
+        rref3.to_here()
+
+        # barrier before check 3
+        wait_until_pending_futures_and_users_flushed()
+        dist.barrier()
+
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertEqual(2, int(info["num_owner_rrefs"]))
+        # no pending users since the fork is finished
+        self.assertEqual(0, int(info["num_pending_users"]))
+
+        # barrier after check 3
+        dist.barrier()
+
+    @dist_init
+    def test_disable_gil_profiling(self):
+        # test that rpc.enable_gil_profiling(false) will result in
+        # GIL wait time not being recorded.
+
+        # GIL profiling should be disabled by default.
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc.rpc_sync(
+            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        # With profiling off, the GIL wait-time key must be absent.
+        self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"])
+        # After enabling profiling and issuing another RPC, the key must be
+        # present in the agent's debug info.
+        rpc.enable_gil_profiling(True)
+        rpc.rpc_sync(
+            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertIn("agent.gil_average_wait_time_us", info)
+
+    @dist_init(setup_rpc=False)
+    def test_local_shutdown(self):
+        # test that we can start RPC and then immediately locally shutdown
+        # without sending any messages.
+        # RPC is initialized here because the decorator is given
+        # setup_rpc=False.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init
+    def test_debug_info(self):
+        # only test keys in this test case. Values should be covered by
+        # individual module debug info tests
+        import torch.distributed.autograd as dist_autograd
+
+        info = _get_debug_info()
+        rref_info = _rref_context_get_debug_info()
+        agent_info = rpc.api._get_current_rpc_agent().get_debug_info()
+        autograd_info = dist_autograd._get_debug_info()
+        common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys()
+        self.assertEqual(0, len(common_keys))
+        expected = {}
+        expected.update(rref_info)
+        expected.update(agent_info)
+        expected.update(autograd_info)
+        # NB: Key ordering is only preserved in python 3.6+. So here, we
+        # manually check keys are equal.
+        for key in expected:
+            self.assertIn(key, info.keys())
+
+        for key in info:
+            self.assertIn(key, expected.keys())
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_handle_send_exceptions(self):
+        # test that if a callee node has gone down, we raise an appropriate
+        # exception instead of just crashing.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc._set_rpc_timeout(10)
+        # This barrier is needed to ensure that some workers do not exit before
+        # others have been brought up.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+        # Only rank 1 sends; every other rank goes straight to the
+        # non-graceful shutdown below.
+        if self.rank == 1:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker = worker_name(dst_rank)
+            # allow destination worker to exit without joining
+            error_str = self.get_shutdown_error_regex()
+            wait_until_node_failure(dst_rank, error_str)
+            fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3))
+            # Shutdown sequence is not very well defined and as a result
+            # we can see any of the error messages defined in get_shutdown_error_regex.
+            with self.assertRaisesRegex(RuntimeError, error_str):
+                fut.wait()
+        # exit all workers non-gracefully.
+        rpc.shutdown(graceful=False)
+
+    @dist_init
+    def test_deadlock(self):
+        # this test is copied from https://github.com/pytorch/pytorch/issues/45089
+        if self.rank == 1:
+            dst1 = worker_name((self.rank + 1) % self.world_size)
+            x = torch.ones(2)
+            y = torch.ones(2)
+            rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()
+
+        dist_initialized = dist.is_initialized()
+        if not dist_initialized:
+            dist.init_process_group(
+                backend="gloo",
+                init_method=self.file_init_method,
+                rank=self.rank,
+                world_size=self.world_size,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_local_shutdown_with_rpc(self):
+        # test that we can start RPC, send RPCs, and then run local shutdown.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        # Send one synchronous RPC to the next rank; its result is not checked.
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # A barrier is needed to ensure that all RPCs are processed.
+        # Otherwise, some RPCs can timeout since the receiving end
+        # has terminated.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    def test_set_and_get_default_rpc_timeout(self):
+        timeout = 0.5
+
+        # A new `RpcBackendOptions` is constructed
+        # when accessing `self.rpc_backend_options`.
+        rpc_backend_options = self.rpc_backend_options
+        rpc_backend_options.rpc_timeout = timeout
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+        set_timeout = rpc.get_rpc_timeout()
+        self.assertEqual(timeout, set_timeout)
+        rpc.shutdown()
+
+    @dist_init
+    def test_default_timeout_used(self):
+        """
+        Tests that if no timeout is passed into rpc_async and rpc_sync, then the
+        default timeout is used.
+        """
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc._set_rpc_timeout(0.001)  # 1 ms
+        # futures should time out and be marked with an exception indicating it as such.
+        futs = [
+            rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=())
+            for _ in range(10)
+        ]
+        expected_error = self.get_timeout_error_regex()
+        for fut in futs:
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                fut.wait()
+
+        # ensure that if a new timeout is set old futures don't time out but new ones do.
+        rpc._set_rpc_timeout(200)  # 200 seconds
+        # create a longstanding RPC.
+        fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
+        # now, set a short timeout.
+        rpc._set_rpc_timeout(0.001)
+        # fut2 should time out, fut1 should not.
+        fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut2.wait()
+        # fut1 was created under the 200 s timeout, so this wait must succeed.
+        fut1.wait()
+
+        # Zero timeout means infinity, so future should run to completion.
+        rpc._set_rpc_timeout(0)
+        rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait()
+
+        # reset to default timeout so shutdown messages can process cleanly.
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init
+    def test_rpc_timeouts(self):
+        # Exercises the per-call `timeout` argument of rpc_async / rpc_sync
+        # and how it interacts with the default set via rpc._set_rpc_timeout.
+        # TODO: enable timeouts for rpc.remote/RRef (https://github.com/pytorch/pytorch/issues/33803)
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst_rank)
+        timeout = 0.1  # 100 ms
+        expected_error = self.get_timeout_error_regex()
+        # Test async UDF
+        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if there is no timeout and we use the default
+        # RPC timeout.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait()
+
+        # Test sync UDF
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
+
+        # Ensure run to completion if there is no timeout and we use the default
+        # RPC timeout.
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
+
+        # If we set a default timeout for RPCs, it should be respected, though
+        # still overridden if we pass in a different timeout to the APIs.
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
+
+        # The RPCs should run to completion since we override the timeout.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait()
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5)
+        # Passing in a zero timeout should ensure that the RPC won't time out.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait()
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0)
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    def test_dist_init_decorator(self):
+        # `dist_init` must work both when called with arguments and when used
+        # bare; in both forms the wrapped function's return value is passed
+        # through to the caller.
+        @dist_init(setup_rpc=False)
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+        # Same check for the bare-decorator form (deliberately reuses the name).
+        @dist_init
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+    def test_use_rpc_pickler(self):
+        class TestPickler:
+            pass
+
+        test_pickler = TestPickler()
+        with _use_rpc_pickler(test_pickler):
+            self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler)
+        self.assertTrue(
+            torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler
+        )
+
+    @dist_init
+    def test_wait_all(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+            self.assertTrue(len(_thread_local_var.future_list) == 1)
+            self.assertTrue(
+                isinstance(_thread_local_var.future_list[0], torch._C.Future)
+            )
+        self.assertTrue(fut.done())
+        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_multiple_call(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            for i in range(20):
+                fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1))
+                res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1))
+                self.assertEqual(res, torch.ones(i, i) + 1)
+                self.assertEqual(fut.wait(), torch.ones(i, i) + 1)
+            self.assertTrue(len(_thread_local_var.future_list) == 20)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_timeout(self):
+        # An rpc_async with a 100 ms timeout is issued inside `_wait_all()` and
+        # never waited on explicitly; the timeout error is expected to be
+        # raised out of the combined `with` statement.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error), _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            timeout = 0.1  # 100 ms
+            rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout)
+        # The thread-local list must be gone even though an error was raised.
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_user_func(self):
+        # The remote function `raise_func` is invoked via rpc_async inside
+        # `_wait_all()`; a ValueError is expected to surface from the `with`
+        # statement.
+        with self.assertRaises(ValueError), _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            rpc.rpc_async(dst, raise_func)
+        # The thread-local list must be gone even though an error was raised.
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_body(self):
+        # `raise_func` is called locally (no RPC) inside `_wait_all()`; the
+        # ValueError must propagate and the thread-local list must still be
+        # removed afterwards.
+        with self.assertRaises(ValueError), _wait_all():
+            raise_func()
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_custom_exception_throw_during_reconstruction(self):
+        """
+        Test that we still throw info about the remote side exception even when
+        we cannot recreate it on client side.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        if self.rank != 0:
+            exc_caught = False
+            dst = worker_name(0)
+            try:
+                rpc.rpc_sync(dst, custom_raise_func, args=())
+            except RuntimeError as e:
+                exc_caught = True
+                msg = str(e)
+                print(f"Got msg {msg}")
+                self.assertTrue("Original exception on remote side was" in msg)
+                self.assertTrue("CustomException" in msg)
+            except BaseException as e:  # noqa: B036
+                raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
+            finally:
+                self.assertTrue(exc_caught)
+
+        dist.barrier()
+
+    # Class-level slot for an Event; the test_wait_all_exit_early_* tests
+    # assign a fresh Event to it before issuing RPCs.
+    timed_out_rpc_event = None
+
+    @staticmethod
+    def timed_out_rpc():
+        # RPC target that blocks until `RpcTest.timed_out_rpc_event` is set in
+        # the process executing it.
+        RpcTest.timed_out_rpc_event.wait()
+
+    @dist_init
+    def test_wait_all_exit_early_python(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, raise_func)
+        fut3 = rpc.rpc_async(dst, raise_func)
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(ValueError, expected_err):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_wait_all_exit_early_builtin(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
+        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(RuntimeError, "size of tensor"):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_wait_all_exit_early_script_function(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
+        fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(RuntimeError, expected_err):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_function_not_on_callee(self):
+        # test that if a function does not exist on a callee, we don't crash,
+        # instead the caller gets a RuntimeError (asserted below to match
+        # "RPC pickler does not serialize").
+        this_module = sys.modules[__name__]
+        caller_worker = "worker0"
+        callee_worker = "worker1"
+
+        if self.rank == 1:
+            # Use delattr to remove the binding of a func on this nodes
+            delattr(this_module, "foo_add")
+            # notify remote end that we have removed it.
+            rpc.rpc_sync(caller_worker, set_value, args=(self.rank,))
+
+        if self.rank == 0:
+            # func exists on caller, but not callee.
+            # wait for remote end to remove the binding of foo_add func.
+            wait_for_value_future()
+            # Ensure that we have the attribute on this module. Otherwise, the test could fail due to a caller-side pickling error.
+            self.assertTrue(hasattr(this_module, "foo_add"))
+            with self.assertRaisesRegex(RuntimeError, "RPC pickler does not serialize"):
+                rpc.rpc_sync(callee_worker, foo_add, args=())
+
+    @dist_init
+    def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self):
+        # Attaches a user RRef to an object that is part of a reference cycle
+        # and then returns. There are no assertions in the body; any check
+        # happens outside it (presumably during the shutdown performed by
+        # `dist_init` -- confirm).
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        a = MyClass(1)
+        b = MyClass(2)
+
+        # This is to make Python not garbage collect a and b.
+        # (a and b reference each other, forming a cycle)
+        a.other = b
+        b.other = a
+
+        n = self.rank
+        a.rref = rpc.remote(dst_worker_name, torch.add, args=(torch.ones(n, n), 2))
+
+    @dist_init(setup_rpc=False)
+    def test_use_rref_after_shutdown(self):
+        # After a graceful shutdown, both fetching and serializing a user RRef
+        # created earlier must raise a RuntimeError mentioning deletion.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # pass in graceful=True to ensure that local UserRRefs are deleted.
+        rpc.shutdown(graceful=True)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot call to_here\\(\\) on it after deletion."
+        ):
+            rref.to_here()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot call fork an UserRRef after deletion."
+        ):
+            import torch.distributed.rpc.internal as internal
+
+            # Serializing the RRef is what is expected to raise here.
+            internal.serialize(rref)
+
+    @staticmethod
+    def _return_gpu_tensor():
+        # Returns a random 3x3 tensor moved to CUDA device 0.
+        return torch.rand(3, 3).cuda(0)
+
+    @staticmethod
+    def _return_gpu_tensor_list():
+        # Returns two random 3x3 tensors, one on CUDA device 0 and one on
+        # CUDA device 1.
+        return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)]
+
+    @staticmethod
+    def _gpu_tensor_list_arg(tensor_list):
+        # Accepts `tensor_list` but does not read it; always returns a fresh
+        # random 3x3 tensor.
+        return torch.rand(3, 3)
+
+    def _create_rref(self):
+        owner_rank = (self.rank + 2) % self.world_size
+        return rpc.remote(
+            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
+        )
+
+    @dist_init
+    def test_user_rrefs_confirmed(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret = rpc.rpc_sync(worker_name(dst_rank), check_rref_confirmed, args=(rref,))
+        self.assertEqual(ret, True)
+
+    @dist_init
+    def test_user_rrefs_confirmed_remote(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret_rref = rpc.remote(worker_name(dst_rank), check_rref_confirmed, args=(rref,))
+        self.assertEqual(ret_rref.to_here(), True)
+
+    @dist_init
+    def test_rref_py_pickle_not_supported(self):
+        # Saving an RRef with torch.save must raise a RuntimeError stating
+        # that RRefs cannot be pickled by the Python pickler.
+        local_rref = RRef(35)
+        with (
+            TemporaryFileName() as fname,
+            self.assertRaisesRegex(
+                RuntimeError, "Can not pickle rref in python pickler"
+            ),
+        ):
+            torch.save(local_rref, fname)
+
+    @dist_init
+    def test_remote_throw(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            raise_or_inc,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            rref.to_here()
+
+    @dist_init
+    def test_non_cont_tensors(self):
+        if self.rank == 0:
+            # Create a non-contiguous tensor.
+            t = torch.rand(5, 5)
+            t_view = t.narrow(1, 2, 2)
+            self.assertFalse(t_view.is_contiguous())
+            t_cont = t_view.contiguous()
+            self.assertTrue(t_cont.is_contiguous())
+            self.assertEqual(t_view, t_cont)
+
+            # Send non-cont tensor over RPC.
+            next_rank = (self.rank + 1) % self.world_size
+            t_ret = rpc.rpc_sync(
+                worker_name(next_rank), non_cont_test, args=(t_view, t_cont)
+            )
+
+            # Verify the returned tensor.
+            self.assertEqual(t_view, t_ret)
+            self.assertFalse(t_ret.is_contiguous())
+
+    @dist_init
+    def test_callback_simple(self):
+        # A callback attached with `then` must see the RPC result, and the
+        # original future's value must be unchanged before and after the
+        # callback has run.
+        set_by_cb = concurrent.futures.Future()
+        n = self.rank + 1
+
+        def callback(fut):
+            ret = fut.wait()
+            self.assertEqual(ret, torch.ones(n, n) * 2)
+            # Publish a modified clone so the test body can observe that the
+            # callback ran.
+            set_by_cb.set_result(ret.clone() + 1)
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+        # Blocks until the callback has stored its result.
+        self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_callback_wrong_arg_num(self):
+        # `my_function` is used as a callback although, per the expected error
+        # text below, it needs more positional arguments than a callback
+        # receives. The original future must still succeed; the error appears
+        # only on the future returned by `then`.
+        n = self.rank + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        cb_fut = fut.then(my_function)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "my\\_function\\(\\) missing 2 required positional arguments"
+        ):
+            cb_fut.wait()
+
+    @dist_init
+    def test_callback_wrong_arg_type(self):
+        # The callback receives the future itself, so `x + 1` is an invalid
+        # operation; the failure is expected as a RuntimeError when waiting on
+        # the future returned by `then`.
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
+        fut1 = fut0.then(lambda x: x + 1)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "unsupported operand type\\(s\\) for \\+"
+        ):
+            fut1.wait()
+
+    @dist_init
+    def test_callback_multi(self):
+        num_cbs = 10
+        n = self.rank + 1
+
+        def callback(idx, fut):
+            ret = fut.wait()
+            self.assertEqual(ret, torch.ones(n, n) * 2)
+            return ret + idx
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        cb_futs = [fut.then(partial(callback, idx)) for idx in range(num_cbs)]
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        for idx in range(num_cbs):
+            self.assertEqual(cb_futs[idx].wait(), torch.ones(n, n) * 2 + idx)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_callback_chain(self):
+        n = self.rank + 1
+
+        def callback(fut):
+            return fut.wait() + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size), torch.add, args=(torch.ones(n, n), 1)
+        )
+
+        num_cbs = 20
+        for _ in range(num_cbs):
+            fut = fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
+
+    @dist_init
+    def test_callback_in_rpc(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, add_use_future_cb, args=(dst2, torch.ones(2, 2), 1, 2))
+        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
+
+    @dist_init
+    def test_callback_with_ret(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        def callback(fut0):
+            fut2 = rpc.rpc_async(dst, torch.add, args=(fut0.wait(), 1)).then(
+                lambda fut1: fut1.wait() + 1
+            )
+
+            return fut2.wait()
+
+        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)).then(callback)
+
+        self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_callback_with_error(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        def callback(fut0):
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut0.wait()
+            raise RuntimeError("Another expected error")
+
+        fut1 = rpc.rpc_async(dst, raise_func).then(callback)
+        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
+            fut1.wait()
+
+    @dist_init
+    def test_callback_none(self):
+        """Passing ``None`` as the callback to ``then`` must raise ``TypeError``."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(TypeError, "incompatible function arguments."):
+            rpc.rpc_async(dst, raise_func).then(None)
+
+    @dist_init
+    def test_add_done_callback(self):
+        set_by_cb = False
+        n = self.rank + 1
+
+        def callback(fut):
+            nonlocal set_by_cb
+            fut.wait()
+            set_by_cb = True
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        fut.add_done_callback(callback)
+        fut_then = fut.then(lambda _: True)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        # We have no guarantee that the add_done_callback fn will execute before the test finishes.
+        # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback
+        fut_then.wait()
+        self.assertTrue(set_by_cb)
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_mark_future_twice(self):
+        fut = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            torch.add,
+            args=(torch.zeros(2, 2), 1),
+        )
+        self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1)
+        with self.assertRaisesRegex(
+            RuntimeError, "Future can only be marked completed once"
+        ):
+            fut.set_result(1)
+
+    @dist_init
+    def test_pickle_future(self):
+        fut = torch.futures.Future()
+        errMsg = "Can not pickle torch.futures.Future"
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg):
+            rpc.rpc_sync(dst, fail_on_fut, args=(fut,))
+
+        with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg):
+            rpc.rpc_async(dst, fail_on_fut, args=(fut,))
+
+        with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg):
+            rpc.remote(dst, fail_on_fut, args=(fut,))
+
+    @dist_init
+    def test_future_done(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
+        fut.wait()
+        self.assertTrue(fut.done())
+
+    @dist_init
+    def test_future_done_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut = rpc.rpc_async(dst, raise_func)
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            fut.wait()
+        self.assertTrue(fut.done())
+
+    def _test_future_cb(self, func):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, func, args=(dst2, torch.ones(2, 2), 1, 2))
+        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
+
+    @dist_init
+    def test_future_in_rpc(self):
+        """Run the shared future-callback check with ``add_use_future_set_result``."""
+        self._test_future_cb(add_use_future_set_result)
+
+    @dist_init
+    def test_future_nested_callback(self):
+        """Run the shared future-callback check with ``add_use_future_nested_cb``."""
+        self._test_future_cb(add_use_future_nested_cb)
+
+    def _test_async_function_raise(self, mode):
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), async_raise_func, mode
+            )
+
+    @dist_init
+    def test_async_function_raise(self):
+        """Async-function error propagation, invoked in SYNC mode."""
+        self._test_async_function_raise(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_async_function_raise_async(self):
+        """Async-function error propagation, invoked in ASYNC mode."""
+        self._test_async_function_raise(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_raise_remote(self):
+        """Async-function error propagation, invoked in REMOTE mode."""
+        self._test_async_function_raise(RPCExecMode.REMOTE)
+
+    def _test_async_function_wrong_return_type(self, mode):
+        errMsg = (
+            "Functions decorated with @rpc\\.async_function must return a "
+            "torch\\.futures\\.Future object,"
+        )
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), async_wrong_type, mode
+            )
+
+    @dist_init
+    def test_async_function_wrong_return_type(self):
+        """Wrong-return-type check, invoked in SYNC mode."""
+        self._test_async_function_wrong_return_type(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_async_function_wrong_return_type_async(self):
+        """Wrong-return-type check, invoked in ASYNC mode."""
+        self._test_async_function_wrong_return_type(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_wrong_return_type_remote(self):
+        """Wrong-return-type check, invoked in REMOTE mode."""
+        self._test_async_function_wrong_return_type(RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_simple(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1))
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    def _test_async_function(self, fn, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        args = (dst2, torch.ones(2, 2), 1, 2)
+        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
+        self.assertEqual(ret, torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_async_function_with_future_ctor(self):
+        """Shared async-function check with ``async_add_with_future_ctor`` (default SYNC mode)."""
+        self._test_async_function(async_add_with_future_ctor)
+
+    @dist_init
+    def test_async_function_with_future_ctor_remote(self):
+        """Shared async-function check with ``async_add_with_future_ctor`` in REMOTE mode."""
+        self._test_async_function(async_add_with_future_ctor, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_chained(self):
+        """Shared async-function check with ``async_add_chained`` (default SYNC mode)."""
+        self._test_async_function(async_add_chained)
+
+    @dist_init
+    def test_async_function_chained_remote(self):
+        """Shared async-function check with ``async_add_chained`` in REMOTE mode."""
+        self._test_async_function(async_add_chained, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_nested(self):
+        """Shared async-function check with ``async_add_nested`` (default SYNC mode)."""
+        self._test_async_function(async_add_nested)
+
+    @dist_init
+    def test_async_function_nested_remote(self):
+        """Shared async-function check with ``async_add_nested`` in REMOTE mode."""
+        self._test_async_function(async_add_nested, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_static_method(self):
+        """Shared async-function check with ``AsyncExecutionClass.static_async_add``."""
+        self._test_async_function(AsyncExecutionClass.static_async_add)
+
+    @dist_init
+    def test_async_static_method_remote(self):
+        """Shared async-function check with ``static_async_add`` in REMOTE mode."""
+        self._test_async_function(
+            AsyncExecutionClass.static_async_add, RPCExecMode.REMOTE
+        )
+
+    @dist_init
+    def test_async_class_method(self):
+        """Shared async-function check with ``AsyncExecutionClass.class_async_add``."""
+        self._test_async_function(AsyncExecutionClass.class_async_add)
+
+    @dist_init
+    def test_async_class_method_remote(self):
+        """Shared async-function check with ``class_async_add`` in REMOTE mode."""
+        self._test_async_function(
+            AsyncExecutionClass.class_async_add, RPCExecMode.REMOTE
+        )
+
+    def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+        rref = rpc.remote(dst1, AsyncExecutionClass)
+
+        x = torch.ones(2, 2)
+        y = torch.ones(2, 2) + 1
+        if mode == RPCExecMode.SYNC:
+            ret = rref.rpc_sync().static_async_add(dst2, x, x, y)
+            ret += rref.rpc_sync().class_async_add(dst2, x, x, y)
+            ret += rref.rpc_sync().bound_async_add(dst2, x, x, y)
+        elif mode == RPCExecMode.ASYNC:
+            ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait()
+            ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait()
+            ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait()
+        elif mode == RPCExecMode.REMOTE:
+            ret = rref.remote().static_async_add(dst2, x, x, y).to_here()
+            ret += rref.remote().class_async_add(dst2, x, x, y).to_here()
+            ret += rref.remote().bound_async_add(dst2, x, x, y).to_here()
+
+        self.assertEqual(ret, 3 * 4 * x)
+
+    @dist_init
+    def test_async_class_rref_proxy(self):
+        """RRef-proxy async method check using the helper's default (SYNC) mode."""
+        self._test_test_async_class_rref_proxy()
+
+    @dist_init
+    def test_async_class_rref_proxy_async(self):
+        """RRef-proxy async method check in ASYNC mode."""
+        self._test_test_async_class_rref_proxy(mode=RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_class_rref_proxy_remote(self):
+        """RRef-proxy async method check in REMOTE mode."""
+        self._test_test_async_class_rref_proxy(mode=RPCExecMode.REMOTE)
+
+    def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        num = 20
+        step = 3
+        args = (dst2, torch.ones(2, 2), num, step)
+        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
+        self.assertEqual(ret, torch.ones(2, 2) + num * step)
+
+    @dist_init
+    def test_async_function_multi_chained(self):
+        """Multi-step check with ``async_add_chained_multi`` (default SYNC mode)."""
+        self._test_async_function_multi(async_add_chained_multi)
+
+    @dist_init
+    def test_async_function_multi_chained_async(self):
+        """Multi-step check with ``async_add_chained_multi`` in ASYNC mode."""
+        self._test_async_function_multi(async_add_chained_multi, RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_multi_chained_remote(self):
+        """Multi-step check with ``async_add_chained_multi`` in REMOTE mode."""
+        self._test_async_function_multi(async_add_chained_multi, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_multi_fanout(self):
+        """Multi-step check with ``async_add_multi_fanout`` (default SYNC mode)."""
+        self._test_async_function_multi(async_add_multi_fanout)
+
+    @dist_init
+    def test_async_function_multi_fanout_async(self):
+        """Multi-step check with ``async_add_multi_fanout`` in ASYNC mode."""
+        self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_multi_fanout_remote(self):
+        """Multi-step check with ``async_add_multi_fanout`` in REMOTE mode."""
+        self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.REMOTE)
+
+    def _test_return_future(self, mode):
+        with self.assertRaisesRegex(
+            RuntimeError, "Can not pickle torch.futures.Future"
+        ):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), return_future, mode
+            )
+
+    @dist_init
+    def test_return_future(self):
+        """Returned-Future pickling check, invoked in SYNC mode."""
+        self._test_return_future(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_return_future_async(self):
+        """Returned-Future pickling check, invoked in ASYNC mode."""
+        self._test_return_future(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_return_future_remote(self):
+        """Returned-Future pickling check, invoked in REMOTE mode."""
+        self._test_return_future(RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_rref_timeout(self):
+        """``rpc.remote`` with a tiny timeout must fail both the creation future and ``to_here``.
+
+        Only rank 0 runs the body; other ranks return immediately.
+        """
+        # This test is similar to ones in FaultyProcessGroupTest, but is meant to be
+        # run with other backends besides ProcessGroup.
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # 10 ms timeout
+        rref = rpc.remote(dst_worker, my_sleep_func, args=(2,), timeout=0.01)
+        # Future corresponding to the remote creation should time out.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref._get_future().wait()
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+
+        # NOTE(review): presumably waits for rank 1 to report one owner and one
+        # fork before the test exits; confirm against the helper's definition.
+        wait_until_owners_and_forks_on_rank(1, 1, rank=1)
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.",
+    )
+    def test_init_pg_then_rpc(self):
+        """Initialise the gloo process group first, then RPC, and exercise both."""
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        # Test RPC.
+        next_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test PG
+        dist.barrier()
+
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.",
+    )
+    def test_init_rpc_then_pg(self):
+        """Initialise RPC first, then the gloo process group, and exercise both."""
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        # Test RPC.
+        next_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test PG
+        dist.barrier()
+
+        rpc.shutdown()
+
+    @dist_init
+    def test_wait_all_with_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [rpc.rpc_async(dst, raise_func) for _ in range(10)]
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+    @dist_init
+    def test_wait_all_with_partial_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [
+            rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1)) for _ in range(10)
+        ]
+
+        futs.append(rpc.rpc_async(dst, raise_func))
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491",
+    )
+    def test_init_rpc_twice(self):
+        """RPC can be initialised, shut down and initialised again in one process.
+
+        The second initialisation uses a modified ``init_method`` and is
+        followed by two RPCs to the next worker to confirm it works.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # Use a different file name for the next initialization
+        # NOTE(review): this binds whatever object ``self.rpc_backend_options``
+        # returns and mutates its ``init_method`` in place; confirm whether the
+        # attribute yields a fresh object on each access.
+        new_backend_options = self.rpc_backend_options
+        new_backend_options.init_method += "init_2"
+
+        # Ensure rpc initialization works again.
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=new_backend_options,
+        )
+
+        # Verify RPCs work after re-init.
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+        rpc.rpc_sync(dst, foo_add, args=())
+
+        rpc.shutdown()
+
+    def test_wrong_types(self):
+        """``init_rpc`` rejects a string backend and a plain-dict options object with ``TypeError``."""
+        # A backend given as a string instead of a BackendType member.
+        with self.assertRaisesRegex(
+            TypeError,
+            "Argument backend must be a member of BackendType",
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend="TENSORPIPE",
+            )
+
+        # Options given as a dict instead of an RpcBackendOptions instance.
+        with self.assertRaisesRegex(
+            TypeError,
+            "Argument rpc_backend_options must be an instance of RpcBackendOptions",
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend=self.rpc_backend,
+                rpc_backend_options={"init_method": self.init_method},
+            )
+
+    def test_cannot_infer_backend_from_options(self):
+        """``init_rpc`` without a backend must fail for unrecognised option classes."""
+        # An exception should be raised if the backend isn't specified but
+        # options are given which are not an instance of any of the known
+        # agents' option classes.
+        rpc_backend_options = FooBackendOptions(self.init_method)
+
+        with self.assertRaisesRegex(TypeError, "Could not infer backend for options"):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                # Do _not_ pass backend.
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    @dist_init
+    def test_owner_rref_backward(self):
+        """Exercise ``RRef.backward`` on locally created (owner) RRefs.
+
+        Covers local autograd, distributed autograd with a context id, a
+        double backward with ``retain_graph=True``, and four error cases.
+        """
+        dst = worker_name((self.rank + 1) % self.world_size)
+        t1 = torch.rand(10, 10, requires_grad=True)
+        # No context id: the gradient lands in ``t1.grad``.
+        rref = rpc.RRef(t1.sum() + t1.sum())
+        rref.backward()
+        expected_grad = torch.ones_like(t1) * 2
+        self.assertEqual(expected_grad, t1.grad)
+
+        # With a context id, gradients are read from the dist_autograd context.
+        with dist_autograd.context() as context_id:
+            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
+            rref = rpc.RRef(t2.sum())
+            rref.backward(context_id)
+            self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1])
+
+        # Double backward.
+        with dist_autograd.context() as context_id:
+            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
+            rref = rpc.RRef(t2.sum())
+            rref.backward(context_id, retain_graph=True)
+            rref.backward(context_id)
+            # Two backward passes in the same context, hence twice the gradient.
+            self.assertEqual(
+                expected_grad * 2, dist_autograd.get_gradients(context_id)[t1]
+            )
+
+        # Test errors.
+        with self.assertRaisesRegex(
+            RuntimeError, "tensors does not require grad and does not have a grad_fn"
+        ):
+            rpc.RRef(torch.rand(10)).backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "grad can be implicitly created only for scalar outputs"
+        ):
+            rpc.RRef(torch.rand(10, requires_grad=True)).backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Could not find autograd context with id: 100"
+        ):
+            rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "RRef should contain a tensor for .backward()"
+        ):
+            rpc.RRef("foo").backward()
+
+    @staticmethod
+    def _sum(x):
+        """Return ``x.sum()``."""
+        return x.sum()
+
+    @staticmethod
+    def _identity(x):
+        """Return ``x`` unchanged."""
+        return x
+
+    @dist_init
+    def test_user_rref_backward(self):
+        """Exercise ``RRef.backward`` on RRefs returned by ``rpc.remote``.
+
+        Covers a double backward inside a dist_autograd context and the error
+        cases for a non-tensor RRef and for a missing context id.
+        """
+        dst = worker_name((self.rank + 1) % self.world_size)
+        t = torch.rand(10, requires_grad=True)
+        with dist_autograd.context() as context_id:
+            rref = rpc.remote(dst, RpcTest._sum, args=(t,))
+            rref.backward(context_id, retain_graph=True)
+            rref.backward(context_id)
+            # Two backward passes in the same context, hence a gradient of 2.
+            self.assertEqual(
+                torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t]
+            )
+
+        with dist_autograd.context() as context_id:
+            # The remote value is the string "foo", not a tensor.
+            rref = rpc.remote(dst, RpcTest._identity, args=("foo",))
+            with self.assertRaisesRegex(
+                RuntimeError, "RRef should contain a tensor for .backward()"
+            ):
+                rref.backward(context_id)
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "User RRefs require 'dist_autograd_ctx_id' to be specified",
+            ):
+                rref.backward()
+
+    @dist_init(setup_rpc=False)
+    def test_shutdown_errors(self):
+        """Check the errors ``rpc.shutdown`` raises when its internals fail.
+
+        Ranks other than 0 patch two ``rpc.api`` internals and expect
+        "simulation rref" from ``rpc.shutdown``; rank 0 expects a
+        "timed out in _all_gather" error. The patched functions are restored
+        in a ``finally`` block.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        if self.rank != 0:
+            # Keep the originals so they can be restored afterwards.
+            og_func = rpc.api._broadcast_to_followers
+            og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs
+
+            # Monkey-patch _broadcast_to_followers to fail, which would ensure
+            # _all_gather on leader raises an exception.
+            def raise_error(sequence_id, objects_map):
+                og_func(sequence_id, objects_map)
+                raise RuntimeError("simulation")
+
+            # Monkey-patch _delete_all_user_and_unforked_owner_rrefs to fail,
+            # which would ensure barrier is not called on followers.
+            def rref_error():
+                raise RuntimeError("simulation rref")
+
+            try:
+                rpc.api._broadcast_to_followers = raise_error
+                rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error
+                with self.assertRaisesRegex(RuntimeError, "simulation rref"):
+                    rpc.shutdown()
+            finally:
+                rpc.api._broadcast_to_followers = og_func
+                rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func
+        else:
+            with self.assertRaisesRegex(RuntimeError, "timed out in _all_gather"):
+                rpc.shutdown()
+
+        # All ranks synchronise on the process group before the test ends.
+        dist.barrier()
+
+    @dist_init
+    def test_my_parameter_server(self):
+        """Run the shared parameter-server scenario, passing ``False`` as its only argument."""
+        self._my_parameter_server(False)
+
+
+class CudaRpcTest(RpcAgentTestFixture):
+    @skip_if_lt_x_gpu(2)
+    @dist_init
+    def test_profiler_remote_cuda(self):
+        if self.rank != 1:
+            return
+
+        dst_cuda_0 = (self.rank + 1) % self.world_size
+        dst_cuda_1 = (self.rank + 2) % self.world_size
+        dst_worker_cuda_0 = worker_name(dst_cuda_0)
+        dst_worker_cuda_1 = worker_name(dst_cuda_1)
+
+        with _profile(use_cuda=True) as p:
+            fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0,))
+            fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1,))
+            fut1.wait()
+            fut2.wait()
+
+        def get_name(event):
+            return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]
+
+        function_events = p.function_events
+        for event in function_events:
+            if event.is_async:
+                self.assertEqual(0, event.device_time_total)
+                self.assertEqual([], event.kernels)
+                self.assertEqual(0, event.device_time)
+            else:
+                if event.node_id == 1:
+                    continue
+                self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1])
+                if get_name(event) in EXPECTED_REMOTE_EVENTS:
+                    self.assertGreater(event.device_time_total, 0)
+                    self.assertEqual(1, len(event.kernels))
+                    kernel = event.kernels[0]
+                    if event.node_id == dst_cuda_0:
+                        self.assertEqual(kernel.device, 0)
+                    if event.node_id == dst_cuda_1:
+                        self.assertEqual(kernel.device, 1)
+                    self.assertGreater(event.device_time, 0)
+
+        # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled
+        # events.
+        remote_events = [event for event in function_events if event.is_remote]
+        remote_event_names = [
+            get_name(event)
+            for event in remote_events
+            if get_name(event) in EXPECTED_REMOTE_EVENTS
+        ]
+        self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS))
+
+
+class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon):
+    def test_mismatched_type_for_options(self):
+        """``init_rpc`` with the TENSORPIPE backend rejects foreign option objects."""
+        # An exception should be raised if the options are not an instance of
+        # TensorPipeRpcBackendOptions.
+        rpc_backend_options = FooBackendOptions(self.init_method)
+
+        with self.assertRaisesRegex(
+            TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`"
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend=rpc.BackendType.TENSORPIPE,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    def test_infer_backend_from_options(self):
+        """Without an explicit backend, TensorPipe options select a ``TensorPipeAgent``."""
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.init_method, _transports=tp_transports()
+        )
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            rank=self.rank,
+            world_size=self.world_size,
+            # Do _not_ pass backend.
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent)
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_set_and_get_num_worker_threads(self):
+        """``num_worker_threads`` set in the options shows up in the agent's debug info."""
+        NUM_THREADS = 27
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_worker_threads=NUM_THREADS,
+            _transports=tp_transports(),
+        )
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        # The value is converted with int() before comparison.
+        self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
+        rpc.shutdown()
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_tensorpipe_set_default_timeout(self):
+        """``rpc_timeout`` set in the options is what ``rpc.get_rpc_timeout`` reports."""
+        # Set a high timeout since it doesn't affect test runtime and ensures
+        # the test doesn't erroneously timeout due to slow machines.
+        timeout = 100
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_worker_threads=self.rpc_backend_options.num_worker_threads,
+            rpc_timeout=timeout,
+            _transports=tp_transports(),
+        )
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        default_timeout = rpc.get_rpc_timeout()
+        self.assertEqual(default_timeout, timeout)
+        rpc.shutdown()
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_tensorpipe_options_throw_on_timedelta_timeout(self):
+        """A ``timedelta`` passed as ``rpc_timeout`` must be rejected with ``TypeError``."""
+        from datetime import timedelta
+
+        timeout = timedelta()
+        # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails
+        with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"):
+            rpc.TensorPipeRpcBackendOptions(
+                init_method=self.rpc_backend_options.init_method,
+                num_worker_threads=self.rpc_backend_options.num_worker_threads,
+                rpc_timeout=timeout,
+            )
+
+    @dist_init
+    def _test_rref_get_type_timeout(self, blocking):
+        """Expect ``RRef._get_type`` to time out while the RRef is still being created.
+
+        With ``blocking=True`` the timeout is raised by the call itself;
+        otherwise it is raised when the returned future is waited on.
+        """
+        # Test where we try to get the type of a RRef from an owner, but RRef
+        # creation is slower than timeout passed into _get_type.
+        dst_rank = (self.rank + 1) % self.world_size
+        dst = worker_name(dst_rank)
+        # NOTE(review): the trailing ``True`` presumably makes MyClass
+        # construction slow; confirm against MyClass.__init__.
+        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
+        timeout = 0.5
+        expected_err = self.get_timeout_error_regex()
+        # Blocking: blocks on inline call
+        if blocking:
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                slow_rref._get_type(timeout=timeout, blocking=blocking)
+        # Non-blocking: blocks on wait
+        else:
+            fut = slow_rref._get_type(timeout=timeout, blocking=blocking)
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                fut.wait()
+
+        # FIXME We wait until the remote completed creating the OwnerRRef
+        # because there's currently a race if we shut down RPC before that.
+        slow_rref.to_here()
+
+    def test_rref_get_type_timeout_blocking(self):
+        """``_get_type`` timeout check with ``blocking=True``."""
+        self._test_rref_get_type_timeout(blocking=True)
+
+    def test_rref_get_type_timeout_non_blocking(self):
+        """``_get_type`` timeout check with ``blocking=False``."""
+        self._test_rref_get_type_timeout(blocking=False)
+
+    @dist_init
+    def test_op_with_invalid_args(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Overloaded torch operator invoked from Python failed to match any schema",
+        ):
+            rpc.rpc_sync(dst, torch.add, args=())
+
+    def _test_rref_proxy_timeout(self, rref_proxy_api):
+        """Check that calls made through an RRef proxy raise the timeout error.
+
+        Args:
+            rref_proxy_api: name of the RRef attribute to exercise; it is
+                compared against ``rpc_async`` and ``remote`` below to decide
+                how the result must be waited on.
+
+        Two cases are covered: ``my_slow_method`` invoked with a 2 second
+        timeout on an already-created RRef, and ``my_instance_method`` invoked
+        with a 0.01 second timeout on an RRef whose remote creation is
+        expected to still be in progress.
+        """
+        dst_rank = (self.rank + 1) % self.world_size
+        dst = worker_name(dst_rank)
+        rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2),))
+        # Ensure RRef is created on remote node.
+        rref.to_here()
+        rref_api = getattr(rref, rref_proxy_api)
+        self.assertTrue(
+            rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}"
+        )
+        expected_error = self.get_timeout_error_regex()
+        timeout = 2
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2))
+            # rpc_sync raises inline; the other two APIs need an explicit wait.
+            if rref_api == rref.rpc_async:
+                result.wait()
+            elif rref_api == rref.remote:
+                result._get_future().wait()
+
+        # Case where rpc.remote() is stuck and exceeds timeout
+        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
+        timeout = 0.01
+        rref_api = getattr(slow_rref, rref_proxy_api)
+        # Note that even when we call rref.rpc_async() in this case, we
+        # time out in future creation, not waiting for future. This is because
+        # rref proxy function calls rref._get_type before returning future,
+        # which blocks on the RRef being created on owner node, until the
+        # specified timeout.
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
+            # rpc_async returns immediately and surfaces a timeout through wait()
+            if rref_api == slow_rref.rpc_async:
+                result.wait()
+
+        # FIXME We wait until the remote completed creating the OwnerRRef
+        # because there's currently a race if we shut down RPC before that.
+        slow_rref.to_here()
+
+    @dist_init
+    def test_rref_proxy_timeout(self):
+        """Run the RRef proxy timeout check for ``rpc_sync``, ``rpc_async`` and ``remote``."""
+        for rpc_api in ["rpc_sync", "rpc_async", "remote"]:
+            self._test_rref_proxy_timeout(rpc_api)
+
+    @dist_init
+    def test_send_to_rank_sparse(self):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        # Test sparse tensor
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor()
+            y = build_sparse_tensor()
+            expected_tensor = x + y
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor(coalesce=True)
+            y = build_sparse_tensor(coalesce=True)
+            expected_tensor = x + y
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+    @dist_init
+    def test_self_py_udf_remote_sparse(self):
+        """Run ``_self_py_udf_remote`` with ``rpc.get_worker_info()`` and three ``build_sparse_tensor()`` values."""
+        self._self_py_udf_remote(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg_sparse(self):
+        """Run ``_self_remote_rref_as_rpc_arg`` against the next rank with ``build_sparse_tensor()`` values."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_rpc_arg(
+            dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg_sparse(self):
+        """Run ``_self_remote_rref_as_rpc_arg`` with ``rpc.get_worker_info()`` and ``build_sparse_tensor()`` values."""
+        self._self_remote_rref_as_rpc_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg_sparse(self):
+        """Run ``_self_remote_rref_as_remote_arg`` against the next rank with ``build_sparse_tensor()`` values."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_remote_arg(
+            dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg_sparse(self):
+        """Run ``_self_remote_rref_as_remote_arg`` with ``rpc.get_worker_info()`` and ``build_sparse_tensor()`` values."""
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    def test_world_size_one_sparse(self):
+        """Run ``_world_size_one`` with two ``build_sparse_tensor()`` values."""
+        self._world_size_one(build_sparse_tensor(), build_sparse_tensor())
+
+    @dist_init
+    def test_multi_rpc_sparse(self):
+        """Run ``_multi_rpc`` with its argument set to ``True``."""
+        self._multi_rpc(True)
+
+    def test_wait_all_workers_sparse(self):
+        """Run ``_wait_all_workers`` with ``heavy_rpc_sparse`` and a ``build_sparse_tensor()`` value."""
+        self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())
+
+    def test_wait_all_workers_twice_sparse(self):
+        """Run ``_wait_all_workers_twice`` with ``heavy_rpc_sparse`` and a ``build_sparse_tensor()`` value."""
+        self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor())
+
+    @dist_init
+    def test_py_sparse_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [build_sparse_tensor(), build_sparse_tensor()]
+        ret = rpc.rpc_sync(worker_name(dst_rank), my_container_sum, args=(a,))
+        self.assertEqual(ret, my_container_sum(a))
+
+    @dist_init
+    def test_nested_rpc_sparse(self):
+        """Run ``_nested_rpc`` with ``nested_rpc_sparse`` and ``build_sparse_tensor() * 2``."""
+        self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2)
+
+    @dist_init
+    def test_stress_heavy_rpc_sparse(self):
+        """Run ``_stress_test_rpc`` on ``heavy_rpc_sparse`` with ``repeat=20`` and one ``build_sparse_tensor()`` argument."""
+        self._stress_test_rpc(
+            heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),)
+        )
+
+    @dist_init
+    def test_builtin_remote_ret_sparse(self):
+        """Run ``_builtin_remote_ret`` with two ``build_sparse_tensor()`` values and ``build_sparse_tensor() * 2``."""
+        self._builtin_remote_ret(
+            build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self_sparse(self):
+        """Run ``_builtin_remote_self`` with two ``build_sparse_tensor()`` values and ``build_sparse_tensor() * 2``."""
+        self._builtin_remote_self(
+            build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_multi_builtin_remote_ret_sparse(self):
+        """Run ``_test_multi_remote_call`` on ``torch.add`` with ``True`` and ``RpcTest._multi_args_fn`` as ``args_fn``."""
+        self._test_multi_remote_call(torch.add, True, args_fn=RpcTest._multi_args_fn)
+
+    @dist_init
+    def test_multi_py_udf_remote_sparse(self):
+        """Run ``_test_multi_remote_call`` on ``my_function`` with ``True`` and ``RpcTest._multi_kwargs_fn`` as ``kwargs_fn``."""
+        self._test_multi_remote_call(
+            my_function, True, kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    @dist_init
+    def test_py_rref_args_sparse(self):
+        """Run ``_py_rref_args`` with four ``build_sparse_tensor()`` values and ``build_sparse_tensor() * 4``."""
+        self._py_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 4,
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share_sparse(self):
+        """Run ``_py_rref_args_user_share`` with six ``build_sparse_tensor()`` values and ``build_sparse_tensor() * 6``."""
+        self._py_rref_args_user_share(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6,
+        )
+
+    @dist_init
+    def test_py_rpc_rref_args_sparse(self):
+        """Run ``_py_rpc_rref_args`` with six ``build_sparse_tensor()`` values and ``build_sparse_tensor() * 6``."""
+        self._py_rpc_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6,
+        )
+
+    @dist_init
+    def test_nested_remote_sparse(self):
+        """Run ``_nested_remote`` with ``nested_remote_sparse`` and the sum of two ``build_sparse_tensor()`` values."""
+        self._nested_remote(
+            nested_remote_sparse, build_sparse_tensor() + build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_nested_rref_sparse(self):
+        """Run ``_nested_rref`` with ``nested_rref_sparse`` and two ``build_sparse_tensor() * 2`` values."""
+        self._nested_rref(
+            nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_nested_rref_stress_sparse(self):
+        """Run ``_nested_rref_stress`` with ``nested_rref_sparse`` and two ``build_sparse_tensor() * 2`` values."""
+        self._nested_rref_stress(
+            nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_my_parameter_server_sparse(self):
+        """Run ``_my_parameter_server`` with its argument set to ``True``."""
+        self._my_parameter_server(True)
+
+    # Test init_rpc without world_size argument
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_init_rpc(self):
+        """Initialize RPC without passing ``world_size``, then shut it down."""
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+    # Dynamic RPC new ranks communicate with existing ranks
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_new_rank_can_communicated_with_existing_rank(self):
+        """Ranks that join RPC after rank 0 must be able to call ``torch.add`` on rank 0.
+
+        Rank 0 calls ``init_rpc`` (without ``world_size``) before the first
+        barrier; every other rank calls it afterwards and then issues one
+        ``rpc_sync`` to ``worker_name(0)``.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # Rank 0 will be initialized with RPC after this barrier
+        dist.barrier()
+
+        if self.rank != 0:
+            # Newly joined ranks will be able to communicate with rank 0, since that was created first
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            result = rpc.rpc_sync(
+                worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1))
+            )
+            self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    # Dynamic RPC existing ranks can communicate with new ranks
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self):
+        """Rank 0, initialized first, must be able to call ``torch.add`` on every later-joining rank.
+
+        Rank 0 calls ``init_rpc`` (without ``world_size``) before the first
+        barrier, the remaining ranks join after it, and after a second barrier
+        rank 0 issues one ``rpc_sync`` to each of ranks ``1..world_size-1``.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # Rank 0 will be initialized with RPC after this barrier
+        dist.barrier()
+
+        # Rest of ranks join after barrier
+        if self.rank != 0:
+            # Newly joined ranks will be able to communicate with rank 0, since that was created first
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # Every rank has called init_rpc once this barrier is passed
+        dist.barrier()
+        if self.rank == 0:
+            for i in range(1, self.world_size):
+                result = rpc.rpc_sync(
+                    worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1))
+                )
+                self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    # Dynamic RPC existing ranks can communicate with new ranks using CUDA rpc
+    @skip_if_lt_x_gpu(2)
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self):
+        """Dynamic-join variant where rank 0 configures device maps to every other rank.
+
+        Only initialization and shutdown are exercised at present: the CUDA
+        RPC calls and their assertions are commented out below (see the TODO).
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            # Map rank 0's devices 1 -> 0 and 0 -> 1 on each other worker
+            for i in range(1, self.world_size):
+                dst = worker_name(i)
+                options.set_device_map(dst, {1: 0})
+                options.set_device_map(dst, {0: 1})
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=options,
+            )
+
+        # Rank 0 will be initialized with RPC after this barrier
+        dist.barrier()
+
+        # Rest of ranks join after barrier
+        if self.rank != 0:
+            # Newly joined ranks will be able to communicate with rank 0, since that was created first
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # TODO: Cuda RPC is failing due to:
+        # terminate called after throwing an instance of 'c10::Error'
+        # what():  0 <= device && static_cast(device) < device_allocator.size()
+        # INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1937,
+        # please report a bug to PyTorch. Allocator not initialized for device 1: did you call init?
+        # dist.barrier()
+        # if self.rank == 0:
+        #     for i in range(1, self.world_size):
+        #         x = torch.ones(2)
+        #         result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1))
+        #         result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1))
+        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0)
+        #         self.assertEqual(torch.device('cuda:0'), result_on_device_0.device)
+        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1)
+        #         self.assertEqual(torch.device('cuda:1'), result_on_device_1.device)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_init_rpc_without_rank(self):
+        # default initialization uses file init
+        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # env init
+        with self.assertRaisesRegex(ValueError, "environment variable RANK expected"):
+            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://")
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+        # tcp init
+        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
+            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+                init_method="tcp://127.0.0.1:23456"
+            )
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_dynamic_and_static_init_rpc_together(self):
+        """Joining a statically sized RPC group without ``world_size`` must raise ``RuntimeError``.
+
+        Ranks below ``world_size - 1`` call ``init_rpc`` with
+        ``world_size=world_size - 1``; the last rank then calls ``init_rpc``
+        without ``world_size`` and the error is expected there.
+        """
+        # Process group over all ranks; it backs the dist.barrier() below
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.file_init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        # Initialize a static rpc group with size = self.world_size - 1
+        world_size_minus_one = self.world_size - 1
+        if self.rank < world_size_minus_one:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=world_size_minus_one,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        dist.barrier()
+
+        # Attempt to add an additional dynamic group member
+        if self.rank == world_size_minus_one:
+            # Expect error message to be thrown
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "RPC group mixes statically and dynamically\
+ initialized members which is not supported.",
+            ):
+                rpc.init_rpc(
+                    name=worker_name(self.rank),
+                    backend=self.rpc_backend,
+                    rank=self.rank,
+                    rpc_backend_options=self.rpc_backend_options,
+                )
+
+
+class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon):
+    def _test_device_maps(self, options, errMsg):
+        """Assert that ``init_rpc`` with ``options`` raises ``ValueError`` matching ``errMsg``.
+
+        Args:
+            options: backend options passed as ``rpc_backend_options``.
+            errMsg: regex the ``ValueError`` message must match.
+        """
+        with self.assertRaisesRegex(ValueError, errMsg):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+        # The failed init must not leave a current RPC agent set
+        self.assertFalse(rpc.api._is_current_rpc_agent_set())
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_wrong_worker_name(self):
+        """A device map keyed by the worker name ``"none_exist"`` must be rejected at init."""
+        options = self.rpc_backend_options
+        options.set_device_map("none_exist", {0: 1})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has invalid target node names in its device maps",
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_max_local_device(self):
+        """A source device index equal to ``torch.cuda.device_count()`` must be rejected at init."""
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {torch.cuda.device_count(): 0})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has source devices with invalid indices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_max_remote_device(self):
+        """A target device index equal to ``torch.cuda.device_count()`` must be rejected at init."""
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {0: torch.cuda.device_count()})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has target devices with invalid indices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_many_to_one(self):
+        """Mapping source devices 1 and 0 to the same target device 0 must be rejected at init."""
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {1: 0})
+        options.set_device_map(dst, {0: 0})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has duplicated target devices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_one_to_many(self):
+        """On rank 0, remapping source device 0 to a second target must raise in ``set_device_map``."""
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            dst = worker_name((self.rank + 1) % self.world_size)
+            options.set_device_map(dst, {0: 1})
+            # Source device 0 is already mapped to 1, so a second target is an error
+            with self.assertRaisesRegex(
+                ValueError, "`set_device_map` only supports 1-to-1 mapping"
+            ):
+                options.set_device_map(dst, {0: 0})
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_min_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"):
+            options.set_device_map(dst, {-1: 0})
+
+        with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"):
+            options.set_device_map(dst, {0: -1})
+
+    @staticmethod
+    def _gpu_add(x, y):
+        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]):
+            return (x + y).to(0)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_gpu(self):
+        """Swap devices 0 and 1 via the device map and run ``_gpu_add`` on the next rank.
+
+        Both arguments are sent from device 0; ``_gpu_add`` only succeeds when
+        it receives them on device 1 and returns its result on device 0, which
+        is expected back on ``torch.device(1)``.
+        """
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {0: 1, 1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        ret = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add,
+            args=(torch.zeros(2).to(0), torch.ones(2).to(0)),
+        )
+        self.assertEqual(ret.device, torch.device(1))
+        self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1))
+        rpc.shutdown()
+
+    @staticmethod
+    def _gpu_add_given_devices(x, y, x_to, y_to, z_to):
+        x_device = "cpu" if x.device.type == "cpu" else x.device.index
+        y_device = "cpu" if y.device.type == "cpu" else y.device.index
+        if x_device == x_to and y_device == y_to:
+            return x.to(z_to) + y.to(z_to)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    def _test_device_maps_gpu(
+        self, x_from, y_from, z_to, device_map, dst=None, fn=None
+    ):
+        fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn
+        x_to = device_map[x_from]
+        y_to = device_map[y_from]
+
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst
+        options.set_device_map(dst, device_map)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(x_from)
+        y = torch.ones(2).to(y_from)
+
+        ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to))
+
+        reverse_device_map = {device_map[k]: k for k in device_map}
+        z_from = reverse_device_map[z_to]
+
+        ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index
+        self.assertEqual(ret_device, z_from)
+        self.assertEqual(ret, torch.ones(2).to(z_from))
+
+        rpc.shutdown()
+
+    def test_device_map_cpu(self):
+        """Map ``{"cpu": "cpu"}``: inputs sent from CPU, result requested on remote CPU."""
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to="cpu",
+            device_map={"cpu": "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_cpu_to_gpu_default(self):
+        """Map ``{"cpu": 0}``: inputs sent from CPU, result requested on remote device 0."""
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=0,
+            device_map={"cpu": 0},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_cpu_to_gpu_non_default(self):
+        """Map ``{"cpu": 1}``: inputs sent from CPU, result requested on remote device 1."""
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=1,
+            device_map={"cpu": 1},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_gpu_to_cpu_default(self):
+        """Map ``{0: "cpu"}``: inputs sent from device 0, result requested on remote CPU."""
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=0,
+            z_to="cpu",
+            device_map={0: "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_to_cpu_non_default(self):
+        """Map ``{1: "cpu"}``: inputs sent from device 1, result requested on remote CPU."""
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=1,
+            z_to="cpu",
+            device_map={1: "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_default(self):
+        """Map ``{0: 0}``: inputs sent from device 0, result requested on remote device 0."""
+        self._test_device_maps_gpu(x_from=0, y_from=0, z_to=0, device_map={0: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_non_default(self):
+        """Map ``{1: 1}``: inputs sent from device 1, result requested on remote device 1."""
+        self._test_device_maps_gpu(x_from=1, y_from=1, z_to=1, device_map={1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_default_to_non_default(self):
+        """Map ``{0: 1}``: inputs sent from device 0, result requested on remote device 1."""
+        self._test_device_maps_gpu(x_from=0, y_from=0, z_to=1, device_map={0: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_non_default_to_default(self):
+        """Map ``{1: 0}``: inputs sent from device 1, result requested on remote device 0."""
+        self._test_device_maps_gpu(x_from=1, y_from=1, z_to=0, device_map={1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_1(self):
+        """Identity map ``{0: 0, 1: 1}``: x from device 0, y from device 1, result on remote device 0."""
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_2(self):
+        """Identity map ``{0: 0, 1: 1}``: x from device 0, y from device 1, result on remote device 1."""
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_3(self):
+        """Identity map ``{0: 0, 1: 1}``: x from device 1, y from device 0, result on remote device 0."""
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_4(self):
+        """Identity map ``{0: 0, 1: 1}``: x from device 1, y from device 0, result on remote device 1."""
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_5(self):
+        """Swapped map ``{0: 1, 1: 0}``: x from device 0, y from device 1, result on remote device 0."""
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_6(self):
+        """Swapped map ``{0: 1, 1: 0}``: x from device 0, y from device 1, result on remote device 1."""
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_7(self):
+        """Swapped map ``{0: 1, 1: 0}``: x from device 1, y from device 0, result on remote device 0."""
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_8(self):
+        """Swapped map ``{0: 1, 1: 0}``: x from device 1, y from device 0, result on remote device 1."""
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_1(self):
+        """Self-targeted call, identity map ``{0: 0, 1: 1}``: x from device 0, y from device 1, result on device 0."""
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=0,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_2(self):
+        """Self-targeted call, identity map ``{0: 0, 1: 1}``: x from device 0, y from device 1, result on device 1."""
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=1,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_3(self):
+        """Self-targeted call, identity map ``{0: 0, 1: 1}``: x from device 1, y from device 0, result on device 0."""
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=0,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_4(self):
+        """Self-targeted call, identity map ``{0: 0, 1: 1}``: x from device 1, y from device 0, result on device 1."""
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=1,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_5(self):
+        """Self-targeted call, swapped map ``{0: 1, 1: 0}``: x from device 0, y from device 1, result on device 0."""
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=0,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_6(self):
+        """Self-targeted call, swapped map ``{0: 1, 1: 0}``: x from device 0, y from device 1, result on device 1."""
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=1,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_7(self):
+        """Self-targeted call, swapped map ``{0: 1, 1: 0}``: x from device 1, y from device 0, result on device 0."""
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=0,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_8(self):
+        """Self-targeted call, swapped map ``{0: 1, 1: 0}``: x from device 1, y from device 0, result on device 1."""
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=1,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @staticmethod
+    def _gpu_add_multi_gpu(x, y):
+        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]):
+            return x.to(0) + y, x - y.to(1)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    def _test_device_maps_multi_gpu(self, dst):
+        """Run ``_gpu_add_multi_gpu`` on ``dst`` with devices 0 and 1 swapped by the device map.
+
+        Args:
+            dst: name of the worker that executes the call.
+        """
+        options = self.rpc_backend_options
+        # Two separate calls build up the swap 0 -> 1, 1 -> 0
+        options.set_device_map(dst, {0: 1})
+        options.set_device_map(dst, {1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(0)
+        y = torch.ones(2).to(1)
+        rets = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, args=(x, y)
+        )
+
+        # The callee returns (device 0, device 1); the swap is expected on the way back
+        self.assertEqual(rets[0].device, torch.device(1))
+        self.assertEqual(rets[1].device, torch.device(0))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_multi_gpu(self):
+        """Run the multi-GPU device-map check against the next rank."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._test_device_maps_multi_gpu(dst)
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_multi_gpu_self(self):
+        """Run the multi-GPU device-map check with this worker as the destination."""
+        dst = worker_name(self.rank)
+        self._test_device_maps_multi_gpu(dst)
+
+    @staticmethod
+    def _gpu_add_return_to_gpu(x, y):
+        if x.device.type == "cpu" and y.device.type == "cpu":
+            return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_in_options(self):
+        """Pass the device map through the ``TensorPipeRpcBackendOptions`` constructor.
+
+        A fresh options object is built with ``device_maps={dst: {0: 1, 1: 0}}``
+        instead of calling ``set_device_map``; the placement of the results of
+        ``_gpu_add_multi_gpu`` is then checked.
+        """
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+                init_method=options.init_method,
+                num_worker_threads=options.num_worker_threads,
+                device_maps={dst: {0: 1, 1: 0}},
+                _transports=tp_transports(),
+            ),
+        )
+
+        rets = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
+            args=(torch.zeros(2).to(0), torch.ones(2).to(1)),
+        )
+        self.assertEqual(rets[0].device, torch.device(1))
+        self.assertEqual(rets[1].device, torch.device(0))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        rpc.shutdown()
+
+    def _test_device_maps_return_to_gpu(self, dst):
+        options = self.rpc_backend_options
+
+        options.set_device_map(dst, {0: 1})
+        options.set_device_map(dst, {1: 2})
+        options.set_device_map(dst, {2: 3})
+        options.set_device_map(dst, {3: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rets = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu,
+            args=(torch.zeros(2), torch.ones(2)),
+        )
+        for i in range(len(rets)):
+            self.assertEqual(rets[i].device, torch.device((3 + i) % 4))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1))
+        self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2))
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_return_to_gpu(self):
+        """Run the return-to-GPU device-map check against the next rank."""
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._test_device_maps_return_to_gpu(dst)
+
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_return_to_gpu_self(self):
+        """Run the return-to-GPU device-map check with this worker as the destination."""
+        dst = worker_name(self.rank)
+        self._test_device_maps_return_to_gpu(dst)
+
+    @staticmethod
+    def _add_to_gpu(x, y):
+        """Return ``x + y`` moved to device 0."""
+        return (x + y).to(0)
+
+    def _test_device_maps_missing_config(self, mode):
+        """Sending a device-0 tensor without a configured device map must raise ``RuntimeError``.
+
+        Args:
+            mode: ``RPCExecMode.SYNC`` or ``RPCExecMode.REMOTE``.
+
+        Raises:
+            ValueError: for any other ``mode``.
+        """
+        dst = worker_name((self.rank + 1) % self.world_size)
+        errMsg = (
+            "TensorPipe RPC backend only supports CPU tensors by default.*"
+            "`set_device_map` on `TensorPipeRpcBackendOptions`"
+        )
+
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            if mode == RPCExecMode.SYNC:
+                rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1))
+            elif mode == RPCExecMode.REMOTE:
+                rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here()
+            else:
+                raise ValueError(f"unexpected mode {mode}")
+
+        # make sure RPC is still functioning
+        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
+        self.assertEqual(ret, torch.ones(2) + 1)
+
+    def _test_device_maps_missing_config_response(self, mode):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        errMsg = "Response device mapping is not available"
+
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            if mode == RPCExecMode.SYNC:
+                rpc.rpc_sync(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._add_to_gpu,
+                    args=(torch.zeros(2), 1),
+                )
+            elif mode == RPCExecMode.REMOTE:
+                rpc.remote(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._add_to_gpu,
+                    args=(torch.zeros(2), 1),
+                ).to_here()
+            else:
+                raise ValueError(f"unexpected mode {mode}")
+
+        # make sure RPC is still functioning
+        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
+        self.assertEqual(ret, torch.ones(2) + 1)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config(self):
+        self._test_device_maps_missing_config(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_missing_config_not_timeout(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        timeout = rpc.get_rpc_timeout()
+
+        tik = time.time()
+        self._test_device_maps_missing_config(RPCExecMode.SYNC)
+        rpc.shutdown()
+        tok = time.time()
+
+        self.assertTrue(tok - tik < timeout)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_loop(self):
+        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
+            self._test_device_maps_missing_config(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_response(self):
+        self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_response_loop(self):
+        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
+            self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_remote(self):
+        self._test_device_maps_missing_config(RPCExecMode.REMOTE)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_remote_response(self):
+        self._test_device_maps_missing_config_response(RPCExecMode.REMOTE)
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_remote(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rref = rpc.remote(
+            dst, TensorPipeAgentCudaRpcTest._add_to_gpu, args=(torch.zeros(2), 1)
+        )
+
+        self.assertEqual(rref.to_here().device.index, 1)
+        self.assertEqual(rref.to_here(), torch.ones(2).to(1))
+
+        rpc.shutdown()
+
+    @staticmethod
+    def _slow_add_on_user_stream(x, y):
+        # Compute ``x + y`` on a newly created side stream after a long
+        # ``torch.cuda._sleep``, then make the current stream wait for it.
+        # s0: the current stream of x's device at call time.
+        s0 = torch.cuda.current_stream(x.device)
+        # s1: a new stream on the same device; the addition runs on it.
+        s1 = torch.cuda.Stream(device=x.device)
+        # s1 waits for work already queued on s0.
+        s1.wait_stream(s0)
+        # Record that x and y are used on s1.
+        x.record_stream(s1)
+        y.record_stream(s1)
+        with torch.cuda.stream(s1):
+            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+            z = x + y
+        # s0 waits for the sleep and the addition queued on s1.
+        s0.wait_stream(s1)
+        # Record that the result is used on s0.
+        z.record_stream(s0)
+        return z
+
+    def _test_custom_stream(self, fn, device_map):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, device_map)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        fn(dst)
+
+        rpc.shutdown()
+
+    def _test_stream_sync(self, dst):
+        x = torch.ones(2, 2).to(0)
+        ret = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, x)
+        )
+        self.assertEqual(ret, 2 * x)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream(self):
+        self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"})
+
+    def _test_stream_multi_async(self, dst):
+        futs = []
+        for i in range(20):
+            x = torch.ones(2, 2).to(0) * i
+            futs.append(
+                rpc.rpc_async(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
+                    args=(x, x),
+                )
+            )
+
+        for i in range(20):
+            self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_multi(self):
+        self._test_custom_stream(self._test_stream_multi_async, {"cuda:0": "cuda:1"})
+
+    @staticmethod
+    def _nested_slow_add_on_user_stream(dst, x, y, z):
+        ret = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, y)
+        )
+
+        return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z)
+
+    def _test_stream_nested_sync(self, dst):
+        x = torch.ones(2, 2).to(0)
+        y = torch.ones(2, 2).to(0) * 2
+        z = torch.ones(2, 2).to(0) * 3
+        nested_dst = worker_name((self.rank + 2) % self.world_size)
+        ret = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
+            args=(nested_dst, x, y, z),
+        )
+        self.assertEqual(ret, 6 * x)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_nested(self):
+        self._test_custom_stream(
+            self._test_stream_nested_sync, {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
+        )
+
+    def _test_stream_nested_multi_async(self, dst):
+        if self.rank == 0:
+            futs = []
+            n = 5
+            xs, ys, zs = [], [], []
+            for i in range(n):
+                x = torch.ones(2, 2).to(0) * (i - 1)
+                y = torch.ones(2, 2).to(0) * i
+                z = torch.ones(2, 2).to(0) * (i + 1)
+                xs.append(x)
+                ys.append(y)
+                zs.append(z)
+                nested_dst = worker_name((self.rank + 2) % self.world_size)
+                futs.append(
+                    rpc.rpc_async(
+                        dst,
+                        TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
+                        args=(nested_dst, x, y, z),
+                    )
+                )
+
+            for i in range(n):
+                self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i])
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_nested_multi(self):
+        self._test_custom_stream(
+            self._test_stream_nested_multi_async,
+            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"},
+        )
+
+    @staticmethod
+    def _gpu_add_wrong_gpus(x, y):
+        if x.is_cuda and y.is_cuda:
+            return x.cpu() + y.cuda()
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_mismatch(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(0)
+        y = torch.ones(2).to(0)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Expected all tensors to be on the same device, but found at least two devices",
+        ):
+            rpc.rpc_sync(
+                dst, TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus, args=(x, y)
+            )
+
+        rpc.shutdown()
+
+    def _test_rref_synchronization(self, local_device, remote_device):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {local_device: remote_device})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 1:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN of MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                x = torch.randn(200, 1, 28, 28).to(local_device)
+                actual = rref.remote().forward(x).to_here()
+                expected = rref.rpc_sync().forward(x)
+                self.assertEqual(actual, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_to_here_synchronization1(self):
+        self._test_rref_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization2(self):
+        self._test_rref_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization3(self):
+        self._test_rref_synchronization("cuda:1", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization4(self):
+        self._test_rref_synchronization("cuda:0", "cuda:1")
+
+    def _test_rref_as_arg_synchronization(
+        self, local_device, remote_device, devicesOptions=None
+    ):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {local_device: remote_device})
+
+        input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size)
+        options.set_device_map(input_src, {remote_device: local_device})
+
+        if devicesOptions is not None:
+            options.set_devices(devicesOptions[self.rank])
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 1:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN of MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device))
+                actual = rref.remote().forward(rref_x, True).to_here()
+                expected = rref.rpc_sync().forward(rref_x, True)
+                self.assertEqual(actual, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_as_arg_synchronization1(self):
+        self._test_rref_as_arg_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization2(self):
+        self._test_rref_as_arg_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization3(self):
+        self._test_rref_as_arg_synchronization("cuda:1", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization4(self):
+        self._test_rref_as_arg_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_as_arg_synchronization5(self):
+        self._test_rref_as_arg_synchronization(
+            "cuda:0",
+            "cuda:0",
+            [["cuda:0"] for _ in range(4)],  # devicesOptions
+        )
+
+    @staticmethod
+    def _rref_relay(rref):
+        return rref.to_here()
+
+    def _test_rref_forward_synchronization(self, local_device, remote_device):
+        options = self.rpc_backend_options
+
+        input_src = worker_name(0)
+        model_dst = worker_name(1)
+        out_relay = worker_name(2)
+
+        if self.rank == 0:
+            # for 1) model construction 2) forward execution
+            options.set_device_map(model_dst, {local_device: remote_device})
+
+            # Forward output will be first copied to the relay node before
+            # returning to the worker. This is intentional, to test RRef
+            # forward CUDA stream synchronizations.
+            options.set_device_map(out_relay, {local_device: local_device})
+        elif self.rank == 1:
+            # worker1 hosts the model and runs forward. The forward functions
+            # calls RRef.to_here(), hence needs to configure the device map
+            options.set_device_map(input_src, {remote_device: local_device})
+        elif self.rank == 2:
+            # worker2 will get the out RRef and call to_here() and hence, needs
+            # to configure device map.
+            options.set_device_map(model_dst, {local_device: remote_device})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 0:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN of MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device))
+                rref_out = rref.remote().forward(rref_input, True)
+                out = rpc.remote(
+                    out_relay, TensorPipeAgentCudaRpcTest._rref_relay, args=(rref_out,)
+                ).to_here()
+                expected = rref.rpc_sync().forward(rref_input, True)
+                self.assertEqual(out, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_forward_synchronization1(self):
+        self._test_rref_forward_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization2(self):
+        self._test_rref_forward_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization3(self):
+        self._test_rref_forward_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization4(self):
+        self._test_rref_forward_synchronization("cuda:1", "cuda:1")
+
+    def _test_owner_rref_forward_synchronization(self, local_device, remote_device):
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            options.set_device_map("w0", {local_device: remote_device})
+            rpc.init_rpc("w0", rank=0, world_size=1, rpc_backend_options=options)
+
+            model = (
+                rpc.remote("w0", torch.nn.Linear, (2048, 20000))
+                .remote()
+                .to(remote_device)
+            )
+            for _ in range(30):
+                data = torch.rand(2048, 2048).to(local_device)
+                output = model.rpc_sync().forward(data)
+                # to_here() internally calls localValue as the caller is
+                # the owner of the RRef.
+                v0 = rpc.RRef(output).remote().sum().to_here().item()
+                v1 = output.sum().item()
+                self.assertEqual(v0, v1)
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_owner_rref_forward_synchronization1(self):
+        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization2(self):
+        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization3(self):
+        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization4(self):
+        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1")
+
+    @staticmethod
+    def _return_tensor_view(i):
+        x = torch.ones(1000, 200).cuda(0) * i
+        torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+        # serialization of the return value will create a new tensor from the
+        # view, which is done outside of the user function.
+        return x.split(100)[0]
+
+    @skip_if_lt_x_gpu(1)
+    def test_tensor_view_as_return_value(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        futs = [
+            rpc.rpc_async(
+                dst, TensorPipeAgentCudaRpcTest._return_tensor_view, args=(i,)
+            )
+            for i in range(5)
+        ]
+
+        for i in range(5):
+            self.assertEqual(torch.ones(100, 200) * i, futs[i].wait())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_devices_option_mismatch(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Node worker0 has unexpected source devices in its device map for worker1",
+        ):
+            dst = worker_name((self.rank + 1) % self.world_size)
+            options = self.rpc_backend_options
+            options.set_device_map(dst, {0: 0})
+            options.set_devices([1])
+
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_devices_option_mismatch_reverse(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Node worker0 has unexpected target devices in its device map for worker1",
+        ):
+            dst = worker_name((self.rank + 1) % self.world_size)
+
+            options = rpc.TensorPipeRpcBackendOptions(
+                init_method=self.rpc_backend_options.init_method,
+                num_worker_threads=self.rpc_backend_options.num_worker_threads,
+                device_maps={dst: {0: 1}},
+                devices=[0],
+            )
+
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_int(self):
+        Future(devices=[0])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_str(self):
+        Future(devices=["cuda:0"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_device(self):
+        Future(devices=[torch.device("cuda", 0)])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_not_cuda(self):
+        with self.assertRaisesRegex(
+            ValueError, "Expected devices to have indices, got cpu"
+        ):
+            Future(devices=["cpu"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_callback_changes_devices(self):
+        # We check proper CUDA stream synchronization by filling the tensor with
+        # the expected value in one stream, and reading it from another stream.
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        # The parent future is declared with both devices.
+        parent_future = Future(devices=["cuda:0", "cuda:1"])
+
+        # Callback: copy the parent's cuda:0 value into the cuda:1 tensor
+        # (non-blocking) and return the cuda:1 tensor as the child's result.
+        def cb(fut):
+            t0 = fut.value()
+            tensor1.copy_(t0, non_blocking=True)
+            return tensor1
+
+        child_future = parent_future.then(cb)
+        # Producer side: on a new cuda:0 stream, queue a sleep, fill tensor0
+        # with ones, and complete the parent future with it.
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        # Consumer side: on a new cuda:1 stream, the child's value must be all
+        # ones.
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_value_on_bad_device(self):
+        # Completing a future declared for cuda:1 with a cuda:0 tensor must
+        # make ``wait()`` on that future raise a ValueError.
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:1"])
+
+        # As a plus, we test that futures still invoke callbacks even in case of
+        # error, and that the child futures are successful if those callbacks
+        # don't access the parent future.
+        def cb(fut):
+            with torch.cuda.device("cuda:1"):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor1.fill_(1)
+                return tensor1
+
+        child_future = parent_future.then(cb)
+        # On a new cuda:0 stream, complete the parent with a tensor that lives
+        # on cuda:0, which is not among the parent's declared devices.
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with self.assertRaisesRegex(
+            ValueError,
+            r"The result contained tensors residing on device\(s\) cuda:0 "
+            r"which are not among the expected device\(s\) cuda:1",
+        ):
+            parent_future.wait()
+        # The child future must still yield the all-ones cuda:1 tensor.
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(1)
+    def test_async_execution_with_cuda_future(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        t = torch.zeros((100,), device="cuda:0")
+        fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,))
+        another_stream = torch.cuda.Stream("cuda:0")
+        with torch.cuda.stream(another_stream):
+            self.assertTrue(torch.eq(fut.wait(), 1).all().item())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_async_execution_nested_with_cuda_future(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        nested_dst = worker_name((self.rank + 2) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        a = torch.ones((100,), device="cuda:0")
+        b = torch.ones((100,), device="cuda:0")
+        c = torch.ones((100,), device="cuda:0")
+        fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c))
+        another_stream = torch.cuda.Stream("cuda:0")
+        with torch.cuda.stream(another_stream):
+            self.assertTrue(torch.eq(fut.wait(), 3).all().item())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_modify_tensor_inplace(self):
+        tensor = torch.zeros((100,), device="cuda:0")
+        future = Future(devices=["cuda:0"])
+        future.set_result(tensor)
+        # It's weird to modify the value of a future once it's complete, but
+        # technically possible. Currently this is considered undefined behavior
+        # (in practice the future will ignore the modification and still
+        # synchronize with the original value). We could one day add logic to
+        # detect and warn or throw in such cases, but for now we just check that
+        # this doesn't crash.
+        tensor.fill_(1)
+        future.wait()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_replace_tensor(self):
+        # Smoke test: replacing the tensor inside an already-set list result
+        # must not make ``wait()`` crash.
+        # The tensor is created inline so the list holds the only local
+        # reference to it.
+        tensor_list = [torch.zeros((100,), device="cuda:0")]
+        future = Future(devices=["cuda:0"])
+        future.set_result(tensor_list)
+        # It's weird to modify the value of a future once it's complete, but
+        # technically possible. Currently this is considered undefined behavior
+        # (in practice the future will ignore the modification and still
+        # synchronize with the original value). We could one day add logic to
+        # detect and warn or throw in such cases, but for now we just check that
+        # this doesn't crash.
+        # We set things up so that the original tensor contained in the list
+        # gets deleted once we replace it with the other one. This will
+        # invalidate any cached information held by the future.
+        tensor_list[0] = torch.ones((100,), device="cuda:0")
+        future.wait()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_with_unpickleable_attributes(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),))
+        rref.rpc_sync().increase(1)
+        ret = rref.rpc_sync().sum()
+        self.assertEqual(ret, 42)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=True
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..021ae60468009d2fd4fa947c90455d99c1c6d54e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py
@@ -0,0 +1,28 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed.rpc as rpc
+from torch.testing._internal.common_distributed import tp_transports
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture):
+    @property
+    def rpc_backend(self):
+        return rpc.backend_registry.BackendType["TENSORPIPE"]
+
+    @property
+    def rpc_backend_options(self):
+        return rpc.backend_registry.construct_rpc_backend_options(
+            self.rpc_backend, init_method=self.init_method, _transports=tp_transports()
+        )
+
+    def get_shutdown_error_regex(self):
+        # FIXME Once we consolidate the error messages returned by the
+        # TensorPipe agent put some more specific regex here.
+        error_regexes = [".*"]
+        return "|".join([f"({error_str})" for error_str in error_regexes])
+
+    def get_timeout_error_regex(self):
+        return "RPC ran for more than"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_pytree_test_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_pytree_test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..737b7d27a1561477c8a3781926453f90cf622c8c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_pytree_test_utils.py
@@ -0,0 +1,28 @@
+import torch
+import torch._dynamo.test_case
+import torch.utils._pytree as pytree
+
+
+class PytreeRegisteringTestCase(torch._dynamo.test_case.TestCase):
+    """TestCase that prunes all temporary pytree registrations and resets Dynamo."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._registered_pytree_nodes: list[type] = []
+        self._registered_constant_nodes: list[type] = []
+
+    def tearDown(self) -> None:
+        for cls in reversed(self._registered_pytree_nodes):
+            pytree._deregister_pytree_node(cls)
+        for cls in reversed(self._registered_constant_nodes):
+            pytree._deregister_pytree_node(cls)
+        torch._dynamo.reset()
+        super().tearDown()
+
+    def register_pytree_node(self, cls, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
+        pytree.register_pytree_node(cls, *args, **kwargs)
+        self._registered_pytree_nodes.append(cls)
+
+    def register_constant(self, cls: type) -> None:
+        pytree.register_constant(cls)
+        self._registered_constant_nodes.append(cls)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc69b7920cf06d24dceac0bb2743004c0b6c64e
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py
@@ -0,0 +1,145 @@
+"""
+This file contains the list of tests that are known to fail under Dynamo
+
+We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures`
+We generate skipIfTorchDynamo* for all tests in `dynamo_skips`
+We generate runWithoutCompiledAutograd for all tests in `compiled_autograd_skips`
+
+For an easier-than-manual way of generating and updating these lists,
+see scripts/compile_tests/update_failures.py
+
+If you're adding a new test, and it's failing PYTORCH_TEST_WITH_DYNAMO=1,
+either add the appropriate decorators to your test or add skips for them
+via test/dynamo_skips and test/dynamo_expected_failures.
+
+*These are not exactly unittest.expectedFailure and unittest.skip. We'll
+always execute the test and then suppress the signal, if necessary.
+If your test crashes, or is slow, please use @skipIfTorchDynamo instead.
+
+The expected failure and skip files are located in test/dynamo_skips and
+test/dynamo_expected_failures. They're individual files rather than a list so
+that git can merge changes more easily.
+"""
+
+import logging
+import os
+import sys
+from typing import Optional
+
+
+def find_test_dir() -> Optional[str]:
+    # Find the path to the dynamo expected failure and skip files.
+    from os.path import abspath, basename, dirname, exists, join, normpath
+
+    if sys.platform == "win32":
+        return None
+
+    # Check relative to this file (local build):
+    test_dir = normpath(join(dirname(abspath(__file__)), "../../../test"))
+    if exists(join(test_dir, "dynamo_expected_failures")):
+        return test_dir
+
+    # Check relative to __main__ (installed builds relative to test file):
+    main = sys.modules["__main__"]
+    file = getattr(main, "__file__", None)
+    if file is None:
+        # Generated files do not have a module.__file__
+        return None
+    test_dir = dirname(abspath(file))
+    while dirname(test_dir) != test_dir:
+        if basename(test_dir) == "test" and exists(
+            join(test_dir, "dynamo_expected_failures")
+        ):
+            return test_dir
+        test_dir = dirname(test_dir)
+
+    # Not found
+    return None
+
+
+# Resolve the test directory once at import time. A falsy result means the
+# expected-failure / skip directories could not be located, which is only
+# logged as a warning here.
+test_dir = find_test_dir()
+if not test_dir:
+    logger = logging.getLogger(__name__)
+    logger.warning(
+        "test/dynamo_expected_failures directory not found - known dynamo errors won't be skipped."
+    )
+
+# Tests that run without strict mode in PYTORCH_TEST_WITH_INDUCTOR=1.
+# Please don't add anything to this list.
+FIXME_inductor_non_strict = {
+    "test_modules",
+    "test_ops",
+    "test_ops_gradients",
+    "test_torch",
+}
+
+# We generate unittest.expectedFailure for all of the following tests
+# when run under PYTORCH_TEST_WITH_DYNAMO=1.
+# see NOTE [dynamo_test_failures.py] for more details
+#
+# This list exists so we can more easily add large numbers of failing tests.
+# Each set holds the entry names found in the matching directory under
+# test_dir; when test_dir is None every set is empty.
+if test_dir is None:
+    dynamo_expected_failures = set()
+    dynamo_skips = set()
+
+    inductor_expected_failures = set()
+    inductor_skips = set()
+
+    compiled_autograd_skips = set()
+else:
+    dynamo_failures_directory = os.path.join(test_dir, "dynamo_expected_failures")
+    dynamo_skips_directory = os.path.join(test_dir, "dynamo_skips")
+
+    dynamo_expected_failures = set(os.listdir(dynamo_failures_directory))
+    dynamo_skips = set(os.listdir(dynamo_skips_directory))
+
+    inductor_failures_directory = os.path.join(test_dir, "inductor_expected_failures")
+    inductor_skips_directory = os.path.join(test_dir, "inductor_skips")
+
+    inductor_expected_failures = set(os.listdir(inductor_failures_directory))
+    inductor_skips = set(os.listdir(inductor_skips_directory))
+
+    compiled_autograd_skips_directory = os.path.join(
+        test_dir, "compiled_autograd_skips"
+    )
+    compiled_autograd_skips = set(os.listdir(compiled_autograd_skips_directory))
+
+# TODO: due to case sensitivity problems, for now list these files by hand
+# (the names below differ only in the case of "T" / "t").
+extra_dynamo_skips = {
+    "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_t_cpu_float32",
+}
+dynamo_skips = dynamo_skips.union(extra_dynamo_skips)
+
+
+# verify some invariants
+for test in (
+    dynamo_expected_failures
+    | dynamo_skips
+    | inductor_expected_failures
+    | inductor_skips
+):
+    if len(test.split(".")) != 2:
+        raise AssertionError(f'Invalid test name: "{test}"')
+
+dynamo_intersection = dynamo_expected_failures.intersection(dynamo_skips)
+if len(dynamo_intersection) > 0:
+    raise AssertionError(
+        "there should be no overlap between dynamo_expected_failures "
+        "and dynamo_skips, got " + str(dynamo_intersection)
+    )
+
+inductor_intersection = inductor_expected_failures.intersection(inductor_skips)
+if len(inductor_intersection) > 0:
+    raise AssertionError(
+        "there should be no overlap between inductor_expected_failures "
+        "and inductor_skips, got " + str(inductor_intersection)
+    )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff4118438e74cd2354997b0f3a76c4d59370b8bc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py
@@ -0,0 +1,11 @@
+import sys
+from typing import Callable, Optional  # noqa: UP035
+
+from torch.utils._config_module import install_config_module
+
+
+# Module-level values of several kinds (list, set, optional callable) defined
+# before the module is handed to install_config_module below.
+e_list = [1]
+e_set = {1}
+e_func: Optional[Callable] = None
+
+# Pass this module object to install_config_module.
+# NOTE(review): presumably this turns the module into a config module used by
+# tests of torch.utils._config_module — confirm against that helper.
+install_config_module(sys.modules[__name__])
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d707d22ab81a6c191283a09ef9bfd54ae80cdc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py
@@ -0,0 +1,513 @@
+# mypy: ignore-errors
+
+import functools
+import unittest
+
+import torch
+from functorch.experimental.control_flow import map
+from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import onlyCUDA
+from torch.testing._internal.common_dtype import all_types_and, custom_types
+from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput
+from torch._higher_order_ops.invoke_subgraph import mark_compile_region
+from torch._higher_order_ops import InvokeQuant, invoke_quant_packed
+
+
+def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(
+        [make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
+        args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)),
+    )
+
+
+def inner_f(x, y0, y1):
+    return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
+
+
+def simple_map(xs, y0, y1):
+    """Call the ``map`` higher-order op with a body that forwards to ``inner_f``."""
+    def f(x, y0, y1):
+        return inner_f(x, y0, y1)
+
+    # ``map`` is the one imported from functorch.experimental.control_flow,
+    # not the Python builtin.
+    return map(f, xs, y0, y1)
+
+
+def nested_map(xs, y0, y1):
+    """Call ``map`` with a body that itself calls ``map``; the innermost body forwards to ``inner_f``."""
+    def f1(xx, y0, y1):
+        def f2(x, y0, y1):
+            return inner_f(x, y0, y1)
+
+        return map(f2, xx, y0, y1)
+
+    # ``map`` is the one imported from functorch.experimental.control_flow,
+    # not the Python builtin.
+    return map(f1, xs, y0, y1)
+
+
+def triple_nested_map(xs, y0, y1):
+    """Three levels of nested ``map`` calls; the innermost body forwards to ``inner_f``."""
+    def f0(xs, y0, y1):
+        def f1(xx, y0, y1):
+            def f2(x, y0, y1):
+                return inner_f(x, y0, y1)
+
+            return map(f2, xx, y0, y1)
+
+        return map(f1, xs, y0, y1)
+
+    # ``map`` is the one imported from functorch.experimental.control_flow,
+    # not the Python builtin.
+    return map(f0, xs, y0, y1)
+
+
+# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST,
+# and do add an OpInfo for your HOP.
+# The OpInfo lets us do automated testing for the HOP to check that
+# your HOP will work correctly with PyTorch!
+#
+# Your new HOP may fail some automated testing. That's OK. If you don't
+# care about certain features (like torch.export), it's fine to xfail those
+# failing tests. It is less fine to xfail a more critical check (like checking
+# if torch.compile works with your HOP, or if your HOP has a docstring).
+# If you don't know if a test is fine to xfail, please ask.
+#
+# There are legitimate reasons why something cannot be added to this list
+# (e.g. it uses executorch which is not in PyTorch). If that's the case then
+# please leave a comment.
+FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
+    "custom_function_call",
+    "autograd_function_apply",
+    "run_and_save_rng_state",
+    "run_with_rng_state",
+    "graphsafe_run_with_rng_state",
+    "out_dtype",
+    "trace_wrapped",
+    'tag_activation_checkpoint',
+    'executorch_call_delegate',
+    'wrap',
+    'wrap_with_set_grad_enabled',
+    'auto_functionalized_v2',
+    'associative_scan',
+    'flat_apply',  # is WIP, doesn't pass any of the tests yet
+    'wrap_with_autocast',
+    'wrap_activation_checkpoint',
+    'run_const_graph',
+    'auto_functionalized',
+    "map",  # T183144629
+    "map_impl",
+    "with_effects",
+    "strict_mode",
+    "_export_tracepoint",
+    "call_torchbind",
+    "triton_kernel_wrapper_mutation",
+    "triton_kernel_wrapper_functional",
+    "hints_wrapper",
+    "dynamo_bypassing_wrapper",  # TODO(soulitzer)
+    "foreach_map",
+    "aoti_call_delegate",
+    "print",
+    "inductor_compiled_code",  # Tested separately in test_inductor_wrap_inductor_compile_regions
+]
+
+# Schema for a test-only custom op: both tensor arguments are annotated as
+# written to (``Tensor(a!)`` / ``Tensor(b!)``) and three tensors are returned.
+torch.library.define(
+    "testlib::mutating_custom_op",
+    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)
+
+
+@torch.library.impl("testlib::mutating_custom_op", "cpu")
+def foo_impl_cpu(x, z):
+    """CPU implementation: add 5 to ``x`` and ``z`` in place, return ``(x, z, x + z)``."""
+    x.add_(5)
+    z.add_(5)
+    return x, z, x + z
+
+
+@torch.library.impl("testlib::mutating_custom_op", "cuda")
+def foo_impl_cuda(x, z):
+    """CUDA implementation: same body as the CPU implementation."""
+    x.add_(5)
+    z.add_(5)
+    return x, z, x + z
+
+
+@torch.library.register_fake("testlib::mutating_custom_op")
+def foo_impl_abstract(x, z):
+    """Fake implementation: return ``(x, z, x + z)`` without the in-place adds."""
+    return x, z, x + z
+
+
+def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
+
+
+def simple_cond(x):
+    """Use ``torch.cond`` to return ``(x.cos(),)`` when ``x.sum() > 2`` and ``(x.sin(),)`` otherwise."""
+    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])
+
+
+def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
+
+
+@mark_compile_region
+def fn_for_invoke_subgraph(x):
+    """Return ``torch.sin(x)``; the function is wrapped with ``mark_compile_region``."""
+    return torch.sin(x)
+
+
+def simple_invoke_subgraph(x):
+    """Call the ``mark_compile_region``-decorated ``fn_for_invoke_subgraph`` on ``x``."""
+    return fn_for_invoke_subgraph(x)
+
+
+def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)
+    )
+
+
+def simple_auto_functionalize(x, z):
+    """Call the ``testlib::mutating_custom_op`` custom op on ``x`` and ``z``."""
+    return torch.ops.testlib.mutating_custom_op(x, z)
+
+
+def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    def score_mod(score, b, h, m, n):
+        return score + h
+
+    q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
+    block_mask = _create_empty_block_mask(q, k)
+    yield SampleInput(q, k, v, score_mod, block_mask)
+
+
+def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        torch.tensor(3),
+        make_arg(2, 3, 4, low=0.1, high=2),
+    )
+
+
+def simple_while_loop(iter_t, x):
+    """Call ``while_loop`` on the carried pair ``(iter_t, x)``.
+
+    The condition is ``iter_t > 0``; the body returns ``(iter_t - 1, x.cos())``.
+    """
+    def cond_fn(iter_t, x):
+        return iter_t > 0
+
+    def body_fn(iter_t, x):
+        return iter_t - 1, x.cos()
+
+    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
+
+
+def simple_while_loop_stack_output(iter_t, x):
+    """Call ``while_loop_stack_output`` with the same condition and body as ``simple_while_loop``.
+
+    The condition is ``iter_t > 0``; the body returns ``(iter_t - 1, x.cos())``.
+    """
+    def cond_fn(iter_t, x):
+        return iter_t > 0
+
+    def body_fn(iter_t, x):
+        return iter_t - 1, x.cos()
+
+    # NOTE(review): the trailing empty tuple is presumably the op's additional
+    # inputs — confirm against while_loop_stack_output's signature.
+    return torch._higher_order_ops.while_loop_stack_output(
+        cond_fn, body_fn, (iter_t, x), tuple()
+    )
+
+
+def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
+    # TODO: once HOPs support DTensor inputs, we should also test DTensors
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        make_arg(2, 3, 4, low=0.1, high=2),
+        make_arg(2, 3, 4, low=0.1, high=2),
+    )
+
+
+def simple_local_map_hop(inp1, inp2):
+    """Run ``inp1.cos() + inp2.sin()`` through ``local_map_hop``.
+
+    The body is symbolically traced with ``torch.fx`` and the placement
+    settings are attached to the traced module's ``meta`` before the call.
+    Asserts that ``torch.distributed`` is available.
+    """
+    def body_gm(inp1, inp2):
+        return inp1.cos() + inp2.sin()
+
+    gm = torch.fx.symbolic_trace(body_gm)
+
+    # Imported lazily, after checking that this build has torch.distributed.
+    assert torch.distributed.is_available()
+    from torch.distributed.tensor.placement_types import Replicate
+
+    # NOTE(review): three placements are listed although body_gm takes two
+    # inputs and returns one value — confirm this is the intended layout.
+    gm.meta["local_map_kwargs"] = {
+        "in_placements": (Replicate(), Replicate(), Replicate()),
+        "out_placements": ((Replicate(), Replicate(), Replicate()),),
+    }
+
+    # TODO: Dynamo would rewrite this op differently
+    return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)
+
+
+def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(
+        make_arg(2, 2, low=0.1, high=2),
+        make_arg(2, 2, 2, low=0.1, high=2),
+    )
+
+
+def simple_scan(init, xs):
+    """Call ``scan`` with a combine function returning ``(carry @ x + x, carry.clone())``."""
+    def combine_fn(carry, x):
+        result = carry @ x + x
+        # The second element is a clone of the incoming carry.
+        return result, carry.clone()
+
+    return torch._higher_order_ops.scan(combine_fn, init, xs)
+
+
+# Module-level InvokeQuant instance used by simple_invoke_quant below.
+quant_tracer = InvokeQuant()
+
+
+def simple_invoke_quant(x):
+    """Call ``quant_tracer`` on ``fn`` with ``x`` passed twice; return the first output times 2."""
+    def fn(x, y):
+        return (torch.sin(x) * y,)
+
+    return quant_tracer(fn, x, x)[0] * 2.0
+
+
+def simple_invoke_quant_packed(x):
+    """Call ``invoke_quant_packed`` on a ``sin`` body; return the first output times 2."""
+    def fn(x):
+        return (torch.sin(x),)
+
+    return invoke_quant_packed(fn, x)[0] * 2.0
+
+
+hop_db = [
+    OpInfo(
+        name="scan",
+        variant_test_name="simple",
+        op=simple_scan,
+        sample_inputs_func=sample_inputs_scan,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_subgraph",
+        variant_test_name="simple",
+        op=simple_invoke_subgraph,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="simple",
+        op=simple_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="nested",
+        op=nested_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="triple_nested",
+        op=triple_nested_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="cond",
+        variant_test_name="simple",
+        op=simple_cond,
+        sample_inputs_func=sample_inputs_cond,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_quant",
+        variant_test_name="simple",
+        op=simple_invoke_quant,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_quant_packed",
+        variant_test_name="simple",
+        op=simple_invoke_quant_packed,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="while_loop",
+        variant_test_name="simple",
+        op=simple_while_loop,
+        sample_inputs_func=sample_inputs_while_loop,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="while_loop_stack_output",
+        variant_test_name="simple",
+        op=simple_while_loop_stack_output,
+        sample_inputs_func=sample_inputs_while_loop,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="auto_functionalize",
+        variant_test_name="simple",
+        op=simple_auto_functionalize,
+        sample_inputs_func=sample_inputs_auto_functionalize,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="flex_attention",
+        variant_test_name="simple",
+        op=flex_attention,
+        sample_inputs_func=sample_inputs_flex_attention,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[onlyCUDA],
+    ),
+    OpInfo(
+        name="flex_attention_backward",
+        variant_test_name="simple",
+        op=flex_attention,
+        sample_inputs_func=sample_inputs_flex_attention,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[onlyCUDA],
+    ),
+    OpInfo(
+        name="local_map_hop",
+        variant_test_name="simple",
+        op=simple_local_map_hop,
+        sample_inputs_func=sample_inputs_local_map_hop,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[
+            onlyCUDA,
+            unittest.skipIf(
+                not torch.distributed.is_available(), "requires distributed build"
+            ),
+        ],
+    ),
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a00e1e1a048a0e12c3e081da4415a980cfd97608
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py
@@ -0,0 +1,379 @@
+# mypy: ignore-errors
+
+from collections import defaultdict
+from collections.abc import Iterable
+import numpy as np
+import torch
+
+import hypothesis
+from functools import reduce
+from importlib.metadata import version
+from hypothesis import assume
+from hypothesis import settings
+from hypothesis import strategies as st
+from hypothesis.extra import numpy as stnp
+from hypothesis.strategies import SearchStrategy
+
+from torch.testing._internal.common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams
+
+# Setup for the hypothesis tests.
+# The tuples are (torch_quantized_dtype, zero_point_enforce), where the last
+# element is the enforced zero_point. If None, any zero_point within the
+# range of the data type is OK.
+
+# Tuple with all quantized data types.
+_ALL_QINT_TYPES = (
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+)
+
+# Enforced zero point for every quantized data type.
+# If None, any zero_point within the range of the data type is OK.
+# Types missing from the mapping also resolve to None (defaultdict factory).
+_ENFORCED_ZERO_POINT = defaultdict(lambda: None, {
+    torch.quint8: None,
+    torch.qint8: None,
+    torch.qint32: 0
+})
+
+def _get_valid_min_max(qparams):
+    scale, zero_point, _quantized_type = qparams
+    adjustment = 1 + torch.finfo(torch.float).eps
+    _long_type_info = torch.iinfo(torch.long)
+    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
+    # make sure intermediate results are within the range of long
+    min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point))
+    max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point))
+    return np.float32(min_value), np.float32(max_value)
+
+# This wrapper wraps around `st.floats` and checks the version of `hypothesis`;
+# if it is too old, it removes the `width` parameter (which was introduced
+# in 3.67.0).
+def _floats_wrapper(*args, **kwargs):
+    if 'width' in kwargs and hypothesis.version.__version_info__ < (3, 67, 0):
+        # As long as nan, inf, min, max are not specified, reimplement the width
+        # parameter for older versions of hypothesis.
+        no_nan_and_inf = (
+            (('allow_nan' in kwargs and not kwargs['allow_nan']) or
+             'allow_nan' not in kwargs) and
+            (('allow_infinity' in kwargs and not kwargs['allow_infinity']) or
+             'allow_infinity' not in kwargs))
+        min_and_max_not_specified = (
+            len(args) == 0 and
+            'min_value' not in kwargs and
+            'max_value' not in kwargs
+        )
+        if no_nan_and_inf and min_and_max_not_specified:
+            if kwargs['width'] == 16:
+                kwargs['min_value'] = torch.finfo(torch.float16).min
+                kwargs['max_value'] = torch.finfo(torch.float16).max
+            elif kwargs['width'] == 32:
+                kwargs['min_value'] = torch.finfo(torch.float32).min
+                kwargs['max_value'] = torch.finfo(torch.float32).max
+            elif kwargs['width'] == 64:
+                kwargs['min_value'] = torch.finfo(torch.float64).min
+                kwargs['max_value'] = torch.finfo(torch.float64).max
+        kwargs.pop('width')
+    return st.floats(*args, **kwargs)
+
+def floats(*args, **kwargs):
+    if 'width' not in kwargs:
+        kwargs['width'] = 32
+    return _floats_wrapper(*args, **kwargs)
+
+"""Hypothesis filter to avoid overflows with quantized tensors.
+
+Args:
+    tensor: Tensor of floats to filter
+    qparams: Quantization parameters as returned by the `qparams`.
+
+Returns:
+    True
+
+Raises:
+    hypothesis.UnsatisfiedAssumption
+
+Note: This filter is slow. Use it only when filtering of the test cases is
+      absolutely necessary!
+"""
+def assume_not_overflowing(tensor, qparams):
+    min_value, max_value = _get_valid_min_max(qparams)
+    assume(tensor.min() >= min_value)
+    assume(tensor.max() <= max_value)
+    return True
+
+"""Strategy for generating the quantization parameters.
+
+Args:
+    dtypes: quantized data types to sample from.
+    scale_min / scale_max: Min and max scales. If None, set to 1e-3 / 1e3.
+    zero_point_min / zero_point_max: Min and max for the zero point. If None,
+        set to the minimum and maximum of the quantized data type.
+        Note: The min and max are only valid if the zero_point is not enforced
+              by the data type itself.
+
+Generates:
+    scale: Sampled scale.
+    zero_point: Sampled zero point.
+    quantized_type: Sampled quantized type.
+"""
+@st.composite
+def qparams(draw, dtypes=None, scale_min=None, scale_max=None,
+            zero_point_min=None, zero_point_max=None):
+    if dtypes is None:
+        dtypes = _ALL_QINT_TYPES
+    if not isinstance(dtypes, (list, tuple)):
+        dtypes = (dtypes,)
+    quantized_type = draw(st.sampled_from(dtypes))
+
+    _type_info = torch.iinfo(quantized_type)
+    qmin, qmax = _type_info.min, _type_info.max
+
+    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
+    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
+    if _zp_enforced is not None:
+        zero_point = _zp_enforced
+    else:
+        _zp_min = qmin if zero_point_min is None else zero_point_min
+        _zp_max = qmax if zero_point_max is None else zero_point_max
+        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))
+
+    if scale_min is None:
+        scale_min = torch.finfo(torch.float).eps
+    if scale_max is None:
+        scale_max = torch.finfo(torch.float).max
+    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))
+
+    return scale, zero_point, quantized_type
+
+"""Strategy to create different shapes.
+Args:
+    min_dims / max_dims: minimum and maximum rank.
+    min_side / max_side: minimum and maximum dimensions per rank.
+
+Generates:
+    Possible shapes for a tensor, constrained to the rank and dimensionality.
+
+Example:
+    # Generates 3D and 4D tensors.
+    @given(Q = qtensor(shapes=array_shapes(min_dims=3, max_dims=4))
+    some_test(self, Q):...
+"""
+@st.composite
+def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
+    """Return a strategy for array shapes (tuples of int >= 1)."""
+    assert min_dims < 32
+    if max_dims is None:
+        max_dims = min(min_dims + 2, 32)
+    assert max_dims < 32
+    if max_side is None:
+        max_side = min_side + 5
+    candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
+    if max_numel is not None:
+        candidate = candidate.filter(lambda x: reduce(int.__mul__, x, 1) <= max_numel)
+    return draw(candidate.map(tuple))
+
+
+"""Strategy for generating test cases for tensors.
+The resulting tensor is in float32 format.
+
+Args:
+    shapes: Shapes under test for the tensor. Could be either a hypothesis
+            strategy, or an iterable of different shapes to sample from.
+    elements: Elements to generate from for the returned data type.
+              If None, the strategy resolves to float within range [-1e6, 1e6].
+    qparams: Instance of the qparams strategy. This is used to filter the tensor
+             such that the overflow would not happen.
+
+Generates:
+    X: Tensor of type `dtype` (float32 by default). Note that NaN and +/-inf are not included.
+    qparams: (If `qparams` arg is set) Quantization parameters for X.
+        The returned parameters are `(scale, zero_point, quantization_type)`.
+        (If `qparams` arg is None), returns None.
+"""
+@st.composite
+def tensor(draw, shapes=None, elements=None, qparams=None, dtype=np.float32):
+    if isinstance(shapes, SearchStrategy):
+        _shape = draw(shapes)
+    else:
+        _shape = draw(st.sampled_from(shapes))
+    if qparams is None:
+        if elements is None:
+            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
+        X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
+        assume(not (np.isnan(X).any() or np.isinf(X).any()))
+        return X, None
+    qparams = draw(qparams)
+    if elements is None:
+        min_value, max_value = _get_valid_min_max(qparams)
+        elements = floats(min_value, max_value, allow_infinity=False,
+                          allow_nan=False, width=32)
+    X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
+    # Recompute the scale and zero_points according to the X statistics.
+    scale, zp = _calculate_dynamic_qparams(X, qparams[2])
+    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
+    if enforced_zp is not None:
+        zp = enforced_zp
+    return X, (scale, zp, qparams[2])
+
+@st.composite
+def per_channel_tensor(draw, shapes=None, elements=None, qparams=None):
+    if isinstance(shapes, SearchStrategy):
+        _shape = draw(shapes)
+    else:
+        _shape = draw(st.sampled_from(shapes))
+    if qparams is None:
+        if elements is None:
+            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
+        X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
+        assume(not (np.isnan(X).any() or np.isinf(X).any()))
+        return X, None
+    qparams = draw(qparams)
+    if elements is None:
+        min_value, max_value = _get_valid_min_max(qparams)
+        elements = floats(min_value, max_value, allow_infinity=False,
+                          allow_nan=False, width=32)
+    X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
+    # Recompute the scale and zero_points according to the X statistics.
+    scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2])
+    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
+    if enforced_zp is not None:
+        zp = enforced_zp
+    # Permute to model quantization along an axis
+    axis = int(np.random.randint(0, X.ndim, 1))
+    permute_axes = np.arange(X.ndim)
+    permute_axes[0] = axis
+    permute_axes[axis] = 0
+    X = np.transpose(X, permute_axes)
+
+    return X, (scale, zp, axis, qparams[2])
+
+"""Strategy for generating test cases for tensors used in Conv.
+The resulting tensors are in float32 format.
+
+Args:
+    spatial_dim: Spatial Dim for feature maps. If given as an iterable, randomly
+                 picks one from the pool to make it the spatial dimension
+    batch_size_range: Range to generate `batch_size`.
+                      Must be tuple of `(min, max)`.
+    input_channels_per_group_range:
+        Range to generate `input_channels_per_group`.
+        Must be tuple of `(min, max)`.
+    output_channels_per_group_range:
+        Range to generate `output_channels_per_group`.
+        Must be tuple of `(min, max)`.
+    feature_map_range: Range to generate feature map size for each spatial_dim.
+                       Must be tuple of `(min, max)`.
+    kernel_range: Range to generate kernel size for each spatial_dim. Must be
+                  tuple of `(min, max)`.
+    max_groups: Maximum number of groups to generate.
+    elements: Elements to generate from for the returned data type.
+              If None, the strategy resolves to float within range [-1e6, 1e6].
+    qparams: Strategy for quantization parameters for X, w, and b.
+             Could be either a single strategy (used for all) or a list of
+             three strategies for X, w, b.
+Generates:
+    (X, W, b, groups, tr): Tensors of type `float32` with the following drawn shapes:
+        X: `(batch_size, input_channels) + feature_map_shape`
+        W: `(output_channels, input_channels_per_group) + kernel_shape`, or
+           `(input_channels, output_channels_per_group) + kernel_shape` when transposed
+        b: `(output_channels,)`
+        groups: Number of groups the input is divided into
+        tr: Whether the transposed weight layout was drawn
+            (always False unless `can_be_transposed` is set)
+Note: X, W, b are tuples of (Tensor, qparams), where qparams could be either
+      None or (scale, zero_point, quantized_type)
+
+
+Example:
+    @given(tensor_conv(
+        spatial_dim=2,
+        batch_size_range=(1, 3),
+        input_channels_per_group_range=(1, 7),
+        output_channels_per_group_range=(1, 7),
+        feature_map_range=(6, 12),
+        kernel_range=(3, 5),
+        max_groups=4,
+        elements=st.floats(-1.0, 1.0),
+        qparams=qparams()
+    ))
+"""
+@st.composite
+def tensor_conv(
+    draw, spatial_dim=2, batch_size_range=(1, 4),
+    input_channels_per_group_range=(3, 7),
+    output_channels_per_group_range=(3, 7), feature_map_range=(6, 12),
+    kernel_range=(3, 7), max_groups=1, can_be_transposed=False,
+    elements=None, qparams=None
+):
+
+    # Resolve the minibatch, in_channels, out_channels, iH/iW, iK/iW
+    batch_size = draw(st.integers(*batch_size_range))
+    input_channels_per_group = draw(
+        st.integers(*input_channels_per_group_range))
+    output_channels_per_group = draw(
+        st.integers(*output_channels_per_group_range))
+    groups = draw(st.integers(1, max_groups))
+    input_channels = input_channels_per_group * groups
+    output_channels = output_channels_per_group * groups
+
+    if isinstance(spatial_dim, Iterable):
+        spatial_dim = draw(st.sampled_from(spatial_dim))
+
+    feature_map_shape = [draw(st.integers(*feature_map_range)) for _ in range(spatial_dim)]
+
+    kernels = [draw(st.integers(*kernel_range)) for _ in range(spatial_dim)]
+
+    tr = False
+    weight_shape = (output_channels, input_channels_per_group) + tuple(kernels)
+    bias_shape = output_channels
+    if can_be_transposed:
+        tr = draw(st.booleans())
+        if tr:
+            weight_shape = (input_channels, output_channels_per_group) + tuple(kernels)
+            bias_shape = output_channels
+
+    # Resolve the tensors
+    if qparams is not None:
+        if isinstance(qparams, (list, tuple)):
+            assert len(qparams) == 3, "Need 3 qparams for X, w, b"
+        else:
+            qparams = [qparams] * 3
+
+    X = draw(tensor(shapes=(
+        (batch_size, input_channels) + tuple(feature_map_shape),),
+        elements=elements, qparams=qparams[0]))
+    W = draw(tensor(shapes=(weight_shape,), elements=elements,
+                    qparams=qparams[1]))
+    b = draw(tensor(shapes=(bias_shape,), elements=elements,
+                    qparams=qparams[2]))
+
+    return X, W, b, groups, tr
+
+
+# Disable Hypothesis deadlines for the tests using this module by registering
+# and loading a dedicated "no_deadline" profile.
+# NOTE(review): an earlier version of this comment said the deadline is set on
+# the currently loaded profile because loading a separate profile overrides
+# user settings; the code below does load a separate profile.
+# First three components of the installed hypothesis version, as ints.
+hypothesis_version = tuple(map(int, version("hypothesis").split(".")[:3]))
+
+if (3, 16, 0) <= hypothesis_version < (3, 27, 0):
+    # Hypothesis 3.16 through 3.26: use `timeout` instead of `deadline`
+    settings.register_profile("no_deadline", timeout=hypothesis.unlimited)
+else:
+    # Every other version (including anything older than 3.16): use `deadline=None`
+    settings.register_profile("no_deadline", deadline=None)
+
+# Activate the profile
+settings.load_profile("no_deadline")
+
+
+def assert_deadline_disabled():
+    """Check that deadlines are effectively disabled across Hypothesis versions."""
+    if hypothesis_version < (3, 27, 0):
+        import warnings
+
+        warning_message = (
+            "Your version of hypothesis is outdated. "
+            "To avoid `DeadlineExceeded` errors, please update. "
+            f"Current hypothesis version: {hypothesis.__version__}"
+        )
+        warnings.warn(warning_message, stacklevel=2)
+    else:
+        assert settings().deadline is None
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3dbb95f4ba9c3c430e27a677fc3850aee2b3549
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -0,0 +1,725 @@
+# mypy: ignore-errors
+
+# Torch
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+import torch.nn.functional as F
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
+from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
+
+import collections
+from copy import deepcopy
+from typing import Any, Union
+import math  # noqa: F401
+
+# Testing utils
+from torch import inf
+
+assert torch.get_default_dtype() == torch.float32
+
+L = 20
+M = 10
+S = 5
+
+
+def unpack_variables(args):
+    if isinstance(args, tuple):
+        return tuple(unpack_variables(elem) for elem in args)
+    else:
+        return args
+
+class dont_convert(tuple):
+    """Tuple subclass that ``create_input`` passes through unchanged instead of
+    interpreting it as a tensor shape."""
+    __slots__ = ()
+
+non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
+
+def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.float, device=None):
+    if not isinstance(call_args, tuple):
+        call_args = (call_args,)
+
+    def map_arg(arg):
+        def maybe_non_contig(tensor):
+            if not non_contiguous or tensor.numel() < 2:
+                return tensor.clone()
+
+            return noncontiguous_like(tensor)
+
+        def conjugate(tensor):
+            return tensor.conj()
+
+        if isinstance(arg, (torch.Size, dont_convert)):
+            return arg
+        elif isinstance(arg, tuple) and len(arg) == 0:
+            var = conjugate(torch.randn((), dtype=dtype, device=device))
+            var.requires_grad = requires_grad
+            return var
+        elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
+            return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
+        # double check casting
+        elif isinstance(arg, non_differentiable):
+            if isinstance(arg.tensor, torch.Tensor):
+                return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+        elif isinstance(arg, torch.Tensor):
+            if arg.is_complex() != dtype.is_complex:
+                raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
+                                   "which is not supported for now")
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
+            v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
+            v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
+            return v
+        elif callable(arg):
+            return map_arg(arg(dtype=dtype, device=device))
+        else:
+            return arg
+    args_out = tuple(map_arg(arg) for arg in call_args)
+    kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
+    return args_out, kwargs_out
+
+# NB: JIT script tests for all nn functional interfaces, script mode does
+# not support in_place operations yet, so no inplace operation tests added.
+# removed all the deprecated functions
+#
+# (
+#   method name,
+#   input size/constructing fn,
+#   args (tuple represents shape of a tensor arg),
+#   test variant name(will be used at test name suffix,
+#       'inplace' skips grad tests),                         // optional
+#   (True, nonfusible_nodes, fusible_nodes) for autodiff     // optional
+#   fn to determine if test should be skipped,               // optional
+#   fn mapping output to part that should be gradcheck'ed,   // optional
+#   kwargs for function,                                     // optional
+# )
+def get_nn_functional_tests():
+    """Return the list of test specifications for ``torch.nn.functional`` JIT
+    script tests.
+
+    Each entry is a tuple laid out as described in the comment above this
+    function: method name, input size or tensor, args tuple, then the optional
+    fields. The list is rebuilt on every call, so the random tensors embedded
+    in it are regenerated each time.
+    """
+    nn_functional_tests = [
+        ('conv1d', (S, S, S), ((S, S, S),)),
+        ('conv2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_transpose1d', (S, S, S), ((S, S, S),)),
+        ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
+        ('avg_pool1d', (S, S, S), (3,)),
+        ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
+        ('avg_pool3d', (S, S, S, S, S), (3,)),
+        ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
+        ('max_pool1d', (S, S, S), (2, 1)),
+        ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
+        ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool3d', (S, S, S, S, S), (2, 1)),
+        ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
+        ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
+        ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
+        ('lp_pool1d', (S, S, S), (2., 3, 2,)),
+        ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
+        ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
+        ('adaptive_max_pool1d', (S, S, S), (5,)),
+        ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
+        ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
+        ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
+        ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
+        ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
+        ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
+        ('alpha_dropout', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
+        ('dropout3d', (S, S, S, S), (0.5,)),
+        ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
+        ('feature_alpha_dropout', (S, S, S), (0.5,)),
+        ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
+        ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
+        ('relu', (S, S, S), (), '', (True,)),
+        ('relu', (S, S, S), (), 'inplace'),
+        ('glu', (S - 1, S - 1, S - 1), (),),
+        ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
+        ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
+        ('relu6', (S, S, S), (), '', (True,)),
+        ('relu6', (S, S, S), (True), 'inplace'),
+        ('elu', (S, S, S), (0.9,),),
+        ('elu', (S, S, S), (0.9, True), 'inplace'),
+        ('selu', (S, S, S), (),),
+        ('selu', (S, S, S), (True), 'inplace'),
+        ('celu', (S, S, S), (0.9,),),
+        ('celu', (S, S, S), (0.9, True), 'inplace'),
+        ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
+        ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
+        ('rrelu', (S, S), (0.1, 0.3, False),),
+        ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
+        ('hardshrink', (S, S, S), (0.4,), '', (True,)),
+        ('tanhshrink', (S, S, S), (),),
+        ('softsign', (S, S, S), (),),
+        ('softplus', (S, S, S), (), '', (True,)),
+        ('softmin', (S, S, S), (0,),),
+        ('softmax', (S, S, S), (0,), '', (True,)),
+        ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
+        ('tanh', (S, S, S), (), '', (True,)),
+        ('sigmoid', (S, S, S), (), '', (True,)),
+        ('silu', (S, S, S), (), '', (True,)),
+        ('log_softmax', (S, S, S), (0,), '', (True,)),
+        ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
+        ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
+        ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
+        ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
+        ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+            'training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), True, ),
+            'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, True, ),
+            'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, None, False, ),
+            'inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+            'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), False, ),
+            'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, False, ),
+            'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
+        ('layer_norm', (S, S, S, S), ([5],), '',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                      non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
+        ('group_norm', (S, S, S), (1, torch.rand(5),),),
+        ('local_response_norm', (S, S, S), (2, ),),
+        ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
+        ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
+        ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
+        ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
+        ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('margin_ranking_loss', (S,), ((S,), (S,)),),
+        ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
+        ('pixel_shuffle', (1, 9, 4, 4), (3,),),
+        ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
+        ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
+        ('pad', (3, 3, 4, 2), ([1, 1],),),
+        ('pairwise_distance', (S, S), ((S, S),),),
+        ('pdist', (S, S), (),),
+        ('cosine_similarity', (S, S), ((S, S),),),
+        ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
+        ('normalize', (S, S, S), (),),
+        ('unfold', (S, S, S, S), ([2, 3]),),
+        ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
+        ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
+        ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
+        ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
+                                       1, 1., non_differentiable(torch.randn(S))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
+                                                               non_differentiable(torch.randn(3, 2))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
+            (non_differentiable(torch.rand(3, 2)),
+             non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
+        ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
+         (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
+          torch.randint(1, S, (S,), dtype=torch.long))),
+        ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
+        ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
+         'nearest_4d_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
+         'nearest_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
+         'bilinear_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
+         'bilinear_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
+         'bicubic_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
+         'bicubic_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
+         'nearest_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
+         'nearest_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
+         'linear_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
+         'linear_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
+         'nearest_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
+         'nearest_5d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
+         'trilinear_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
+         'trilinear_5d_with_size_not_recompute_scale_factor'),
+    ]
+    return nn_functional_tests
+
+script_template = '''
+def the_method({}):
+    return {}
+'''
+
+def value_to_literal(value):
+    if isinstance(value, str):
+        # Quotes string and escapes special characters
+        return ascii(value)
+    if isinstance(value, torch.Tensor):
+        return 'torch.' + str(value)
+    else:
+        return str(value)
+
+def get_call(method_name, func_type, args, kwargs):
+    kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
+    self_arg = args[0]
+    if func_type == 'method':
+        args = args[1:]
+
+    argument_str = ', '.join(args)
+    argument_str += ', ' if len(args) and len(kwargs) else ''
+    argument_str += kwargs_str
+
+    if func_type == 'functional' or func_type == 'function':
+        call = f'torch.{method_name}({argument_str})'
+    elif func_type == 'method':
+        call = f'{self_arg}.{method_name}({argument_str})'
+    elif func_type == 'nn_functional':
+        call = f'torch.nn.functional.{method_name}({argument_str})'
+    else:
+        raise TypeError('Unsupported function type')
+
+    return call
+
+def get_constant(x):
+    if x == inf:
+        return 'math.inf'
+    if x == -inf:
+        return '-math.inf'
+    return x
+
+def get_script_args(args):
+    formals: list[str] = []
+    tensors: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+    actuals: list[str] = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            name = f'i{len(formals)}'
+            formals.append(name)
+            actuals.append(name)
+            tensors.append(arg)
+        elif is_iterable_of_tensors(arg):
+            name = f'i{len(formals)}'
+            formals.append(name + ': List[torch.Tensor]')
+            actuals.append(name)
+            tensors.append(list(arg))
+        elif isinstance(arg, str):
+            actuals.append(f"'{arg}'")
+        else:
+            actuals.append(str(get_constant(arg)))
+    return (formals, tensors, actuals)
+
+# create a script function from (method_name, func_type, *args, **kwargs),
+# and returns the compiled function and example inputs
+def gen_script_fn_and_args(method_name, func_type, *args, **kwargs):
+    """Compile a one-line script function that performs the described call.
+
+    The positional ``args`` are split by ``get_script_args`` into script
+    parameters and inlined literals, the call text is built by ``get_call``,
+    and the result is substituted into ``script_template`` and compiled with
+    ``torch.jit.CompilationUnit``.
+
+    Returns:
+        ``(fn, tensors)``: the compiled ``the_method`` and the inputs to pass
+        to it.
+    """
+    formals, tensors, actuals = get_script_args(args)
+    call = get_call(method_name, func_type, actuals, kwargs)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    return CU.the_method, tensors
+
+# create a script function from (name, func_type),
+# returns a function takes in (args, kwargs) and runs the compiled function
+def create_script_fn(self, method_name, func_type):
+    """Return a callable that scripts and runs ``method_name`` on each call.
+
+    Args:
+        self: Object providing ``assertExportImport`` (presumably the test
+            case instance - confirm against callers).
+        method_name: Name of the function or method to script.
+        func_type: Call kind forwarded to ``gen_script_fn_and_args``.
+    """
+    # script_fn returns the output of the compiled function; the graph used
+    # for the call is stored on the `script_fn.last_graph` attribute.
+    def script_fn(*args, **kwargs):
+        fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
+        self.assertExportImport(fn.graph, tensors)
+        output = fn(*tensors)
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        script_fn.last_graph = fn.graph_for(*tensors)  # type: ignore[attr-defined]
+        return output
+    return script_fn
+
+class SplitInputs:
+    all_tensors: list[Any]
+    tensor_args: list[Any]
+    nontensor_args: list[Any]
+    arg_types: list[str]
+    tensor_kwargs: dict[str, Any]
+    kwarg_order: list[str]
+    nontensor_kwargs: dict[str, Any]
+    kwarg_types: dict[str, Any]
+
+    @staticmethod
+    def _is_tensor_input(arg):
+        return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)
+
+    def __init__(self, args, kwargs):
+        self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args]
+        self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()}
+        self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)]
+        self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)]
+        self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)}
+        self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)}
+        self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]]
+        self.kwarg_order = [k for k, v in kwargs.items()]
+
+    def nontensors_match(self, other: 'SplitInputs'):
+        if self.arg_types != other.arg_types:
+            return False
+        if self.kwarg_types != other.kwarg_types:
+            return False
+        if self.kwarg_order != other.kwarg_order:
+            return False
+        if self.nontensor_args != other.nontensor_args:
+            return False
+        if self.nontensor_kwargs != other.nontensor_kwargs:
+            return False
+        return True
+
+# make a new function where all non-tensor arguments in 'args' have been partially
+# applied, and all tensor arguments remain.
+# used to trace functions when some arguments are not tensors
+def partial_apply_nontensors(fn, args, kwargs):
+    inputs = SplitInputs(args, kwargs)
+
+    def new_fn(*tensors_):
+        tensors = iter(tensors_)
+        full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)]
+        full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()}
+        return fn(*full_args, **full_kwargs)
+
+    return new_fn, inputs
+
+# create a trace function from input fn
+#
+# `self` supplies assertExportImport / assertTrue. When `cache_traced_fn` is
+# true, the first traced function and its SplitInputs are stored on the
+# returned callable (`traced`, `split_inputs`) and reused by later calls.
+# Every call also stores `last_graph` and `graph` on the returned callable.
+def create_traced_fn(self, fn, cache_traced_fn=False):
+    def traced_fn(*inputs, **kwargs):
+        # `check_trace` is set to False because check_trace is run with @no_grad
+        # Also, `check_against_reference` already does all the checks
+        # against python function
+        fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs)
+        # Trace when caching is off, or when nothing has been cached yet.
+        if not cache_traced_fn or not hasattr(traced_fn, 'traced'):
+            traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False)
+            self.assertExportImport(traced.graph, split_inputs.all_tensors)
+            output = traced(*split_inputs.all_tensors)
+            if cache_traced_fn:
+                traced_fn.traced = traced
+                traced_fn.split_inputs = split_inputs
+        else:
+            # Guard to check that nontensor inputs are the same as during tracing
+            self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs))
+            output = traced_fn.traced(*split_inputs.all_tensors)
+            traced = traced_fn.traced
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors)  # type: ignore[attr-defined]
+        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
+        return output
+    return traced_fn
+
+# Test names known to be failing in script.
+# NOTE(review): nothing in the visible part of this file reads this set;
+# consumers presumably live elsewhere.
+EXCLUDE_SCRIPT = {
+    'test_norm_fro_default',
+    'test_norm_fro_cpu',
+    'test_norm_nuc',
+    'test_norm_fro',
+    'test_norm_nuc_batched',
+
+    # aten op has additional cudnn argument
+    'test_nn_unfold',
+
+    # flaky test - TODO fix
+    'test_nn_ctc_loss',
+
+    # unknown builtin op
+    'test_nn_fold',
+
+    # jit doesn't support sparse tensors.
+    'test_to_sparse',
+    'test_to_sparse_dim',
+}
+
+# generates a script function and set of example inputs
+# from a specified test in the format of nn_functional_tests
+#
+# Returns the (script_fn, inputs) pair produced by gen_script_fn_and_args for
+# the "nn_functional" function `name`, called on an input created from
+# `self_size` followed by the inputs created from `args`.
+# NOTE(review): `test_name`, `self_tensor`, `args_tensor`, `f_args_tensor` and
+# `extra_args` are computed/accepted but never used by the visible code.
+def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
+    test_name = 'test_nn_' + name
+
+    if variant_name != '':
+        test_name = test_name + '_' + variant_name
+
+    self_variable = create_input((self_size,))[0][0]
+
+    # need to record this because methods can change the size (e.g. unsqueeze)
+    args_variable, _kwargs_variable = create_input(args)
+
+    self_tensor = deepcopy(self_variable.data)
+    args_tensor = deepcopy(unpack_variables(args_variable))
+
+    f_args_variable = (self_variable,) + args_variable
+    f_args_tensor = (self_tensor,) + args_tensor  # noqa: F841
+    # Emit hooks are disabled while the script function is compiled.
+    with torch._jit_internal._disable_emit_hooks():
+        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
+    return script_fn, inputs
+
+
+
+# Module test names (as produced by get_nn_mod_test_name) for which
+# try_get_nn_module_compiled_mod_and_inputs returns early without scripting.
+EXCLUDE_SCRIPT_MODULES = {
+    'test_nn_AdaptiveAvgPool2d_tuple_none',
+    'test_nn_AdaptiveAvgPool3d_tuple_none',
+    'test_nn_AdaptiveMaxPool2d_tuple_none',
+    'test_nn_AdaptiveMaxPool3d_tuple_none',
+
+    # Doesn't use future division, so this is not supported
+    'test_nn_CrossMapLRN2d',
+    # Derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented
+    'test_nn_TransformerDecoderLayer_gelu_activation',
+    'test_nn_TransformerDecoderLayer_relu_activation',
+    'test_nn_TransformerEncoderLayer_gelu_activation',
+    'test_nn_TransformerEncoderLayer_relu_activation',
+    'test_nn_Transformer_multilayer_coder',
+}
+
+# Source template for a scripted `forward` method: the first placeholder is
+# the parameter list, the second is the expression that gets returned.
+script_method_template = '''
+def forward({}):
+    return {}
+'''
+
+def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
+    """Return a factory that wraps ``nn_module`` in a scripted module.
+
+    The returned ``script_module(*args, **kwargs)`` builds a
+    ``torch.jit.ScriptModule`` whose ``submodule`` attribute is
+    ``nn_module(*constructor_args)`` and whose ``forward`` (defined from
+    ``script_method_template``) calls ``self.submodule`` with the actuals
+    derived from ``args``. A truthy ``is_constant`` keyword lists
+    ``'submodule'`` in ``__constants__``.
+
+    When ``self`` is truthy, the module is checked with
+    ``self.assertExportImportModule`` and invoked once on ``args``. The
+    module's graph is stored on ``create_script_module.last_graph``.
+
+    Note: the outer ``*args``/``**kwargs`` are not read; the inner function
+    declares its own parameters with the same names.
+    """
+    def script_module(*args, **kwargs):
+        _formals, tensors, actuals = get_script_args(args)
+
+        # The actual-argument expressions double as the forward() parameter names.
+        method_args = ', '.join(['self'] + actuals)
+        call_args_str = ', '.join(actuals)
+        call = f"self.submodule({call_args_str})"
+        script = script_method_template.format(method_args, call)
+
+        submodule_constants = []
+        if kwargs.get('is_constant'):
+            submodule_constants = ['submodule']
+
+        # Create module to use the script method
+        class TheModule(torch.jit.ScriptModule):
+            __constants__ = submodule_constants
+
+            def __init__(self) -> None:
+                super().__init__()
+                self.submodule = nn_module(*constructor_args)
+
+        def make_module(script):
+            module = TheModule()
+            # check __repr__
+            str(module)
+            module.define(script)
+            return module
+
+        module = make_module(script)
+        # `self` here is the outer argument (the test case, or None).
+        if self:
+            self.assertExportImportModule(module, tensors)
+            module(*args)
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        create_script_module.last_graph = module.graph  # type: ignore[attr-defined]
+        return module
+    return script_module
+
+def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'):
+    formals, tensors, actuals = get_script_args(args)
+    call = get_call(method_name, func_type, actuals, kwargs)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    # to clean up IR
+    torch._C._jit_pass_inline(CU.the_method.graph)
+    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
+    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)
+
+def get_nn_module_name_from_kwargs(**kwargs):
+    if 'module_name' in kwargs:
+        return kwargs['module_name']
+    elif 'fullname' in kwargs:
+        return kwargs['fullname']
+    elif 'constructor' in kwargs:
+        return kwargs['constructor'].__name__
+
+def get_nn_mod_test_name(**kwargs):
+    if 'fullname' in kwargs:
+        test_name = kwargs['fullname']
+    else:
+        test_name = get_nn_module_name_from_kwargs(**kwargs)
+        if 'desc' in kwargs:
+            test_name = f"{test_name}_{kwargs['desc']}"
+    return f'test_nn_{test_name}'
+
+def get_nn_module_class_from_kwargs(**kwargs):
+    name = get_nn_module_name_from_kwargs(**kwargs)
+    index = name.find("_")
+    if index == -1:
+        return name
+    else:
+        return name[0:name.find("_")]
+
+def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
+    name = get_nn_module_name_from_kwargs(**kwargs)
+
+    if 'desc' in kwargs and 'eval' in kwargs['desc']:
+        # eval() is not supported, so skip these tests
+        return
+
+    test_name = name
+    if 'desc' in kwargs:
+        test_name = f"{test_name}_{kwargs['desc']}"
+    test_name = get_nn_mod_test_name(**kwargs)
+
+    if test_name in EXCLUDE_SCRIPT_MODULES:
+        return
+    if 'constructor' in kwargs:
+        nn_module = kwargs['constructor']
+    else:
+        nn_module = getattr(torch.nn, name)
+
+    if "FunctionalModule" in str(nn_module):
+        return
+
+    if 'constructor_args_fn' in kwargs:
+        constructor_args = kwargs['constructor_args_fn']()
+    else:
+        constructor_args = kwargs.get('constructor_args', ())
+
+    # Set up inputs from tuple of sizes or constructor fn
+    input_dtype = torch.double
+    if 'input_fn' in kwargs:
+        input = kwargs['input_fn']()
+        if isinstance(input, torch.Tensor):
+            input = (input,)
+
+        if all(tensor.is_complex() for tensor in input):
+            input_dtype = torch.cdouble
+    else:
+        input = (kwargs['input_size'],)
+
+    # Extra parameters to forward()
+    if 'extra_args' in kwargs:
+        input = input + kwargs['extra_args']
+
+    if 'target_size' in kwargs:
+        input = input + (kwargs['target_size'],)
+    elif 'target_fn' in kwargs:
+        if torch.is_tensor(input):
+            input = (input,)
+        input = input + (kwargs['target_fn'](),)
+
+    args_variable, _kwargs_variable = create_input(input, dtype=input_dtype)
+    f_args_variable = deepcopy(unpack_variables(args_variable))
+    out_var = deepcopy(f_args_variable)
+
+
+    _args, mod = f_args_variable, create_script_module(
+        None, nn_module, constructor_args, *f_args_variable
+    )(*f_args_variable)
+
+    return mod, out_var
+
+
+def get_all_nn_module_tests():
+    """Return ``module_tests`` + ``get_new_module_tests()`` + extra entries.
+
+    Each extra entry is a dict using the keys consumed by
+    ``try_get_nn_module_compiled_mod_and_inputs`` (``module_name``,
+    ``constructor_args``, ``input_size``, ``extra_args``) plus ``slowTest``.
+    The ``torch.randn`` extra arguments are created afresh on every call.
+    """
+    # additional modules test
+    # TODO: delete this list once we make all nn_tests work
+    additional_module_tests = [
+        {
+            'module_name': 'Bilinear',
+            'constructor_args': (S, S, M),
+            'input_size': (S, S),
+            'extra_args': ((S, S),)
+        },
+        {
+            'module_name': 'RNNCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'LSTMCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'GRUCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'MultiheadAttention',
+            'constructor_args': (128, 8),
+            'input_size': (10, 8, 128),
+            'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
+            'slowTest': True
+        },
+        {
+            'module_name': 'Transformer',
+            'constructor_args': (1, 1, 1, 1, 2),
+            'input_size': (3, 1, 1),
+            'extra_args': (torch.randn(1, 1, 1),),
+            'slowTest': True
+        }
+    ]
+
+    return module_tests + get_new_module_tests() + additional_module_tests
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aab838e8c87b229a824f1b4548f035cea614bfb
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py
@@ -0,0 +1,896 @@
+# mypy: ignore-errors
+
+# Torch
+from torch.autograd import Variable
+from torch.autograd.function import _nested_map
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+
+from torch.onnx import OperatorExportTypes
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+import torch.jit.quantized
+import zipfile
+import functools
+
+# Testing utils
+from torch.testing import FileCheck
+from torch.testing._internal.common_utils import IS_WINDOWS, \
+    freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \
+    is_iterable_of_tensors
+from torch.testing._internal.common_jit import JitCommonTestCase
+from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401
+
+# Standard library
+from contextlib import contextmanager
+from functools import reduce
+from io import StringIO
+from collections import defaultdict
+
+import importlib.util
+import inspect
+import io
+import math
+import os
+import pickle
+import sys
+import tempfile
+import textwrap
+from importlib.abc import Loader
+from typing import Any, Union
+
+# Module-level capability flags, derived from the CUDA runtime state at
+# import time.
+RUN_CUDA = torch.cuda.is_available()
+RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
+RUN_CUDA_HALF = RUN_CUDA
+# HIP supports half, no version check necessary
+if torch.cuda.is_available() and not torch.version.hip:
+    # NOTE: CUDA_VERSION is only defined on this (CUDA, non-HIP) path.
+    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
+    for d in range(torch.cuda.device_count()):
+        # RUN_CUDA_HALF is cleared if any device reports a major compute
+        # capability below 6.
+        major = torch.cuda.get_device_capability(d)[0]
+        if (major < 6):
+            RUN_CUDA_HALF = False
+
+def execWrapper(code, glob, loc):
+    """Run ``exec(code, glob, loc)``: ``glob`` is the globals mapping, ``loc`` the locals."""
+    exec(code, glob, loc)
+
+def do_input_map(fn, input):
+    """Apply ``fn`` through ``_nested_map`` to ``input``, selecting ``torch.Tensor`` values."""
+    return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+def clear_class_registry():
+    """Reset TorchScript class state.
+
+    Clears the C++ class registry, replaces the recursive-scripting
+    ``concrete_type_store`` with a fresh ``ConcreteTypeStore`` and clears the
+    Python-side class state in ``torch.jit._state``.
+    """
+    torch._C._jit_clear_class_registry()
+    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
+    torch.jit._state._clear_class_state()
+
+def get_execution_plan(graph_executor_state):
+    execution_plans = list(graph_executor_state.execution_plans.values())
+    num_plans = len(execution_plans)
+    if num_plans != 1:
+        raise RuntimeError('This test assumes this GraphExecutor should '
+                           f'only have one execution plan, got: {num_plans}')
+    return execution_plans[0]
+
+class _AssertRaisesRegexWithHighlightContext:
+    """
+    A context manager that is useful for checking that error messages highlight
+    the correct part of the source code.
+    """
+
+    def __init__(self, test_case, exception, regex, highlight):
+        # test_case: object providing assertRaisesRegex
+        # exception / regex: forwarded to test_case.assertRaisesRegex
+        # highlight: source fragment expected to be highlighted; a falsy value
+        #            skips the highlight check
+        self.test_case = test_case
+        self.exception_type = exception
+        self.regex = regex
+        self.highlight = highlight
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        # Re-raise the captured exception inside assertRaisesRegex so that the
+        # test case's own type/regex check is applied to it. If the body
+        # raised nothing, nothing is raised inside the inner context either.
+        with self.test_case.assertRaisesRegex(self.exception_type, self.regex):
+            if type:
+                raise value
+
+        if self.highlight:
+            FileCheck().check_source_highlighted(self.highlight).run(str(value))
+
+        # Returning True suppresses the exception raised in the with-body.
+        return True
+
+# Node kind that JitTestCase.assertAllFused searches for.
+FUSION_GROUP = "prim::TensorExprGroup"
+
+class JitTestCase(JitCommonTestCase):
+    _do_cuda_memory_leak_check = True
+    _restored_warnings = False
+
+    class capture_stdout(list):
+        """
+        Replace sys.stdout with a temporary StringIO
+
+        Used as a context manager; on exit the captured text is appended to
+        this list as a single string and the previous sys.stdout is restored.
+        """
+        def __enter__(self):
+            # Remember the stream being replaced so __exit__ can put it back.
+            self.sys_stdout = sys.stdout
+            self.stringio = StringIO()
+            sys.stdout = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stdout = self.sys_stdout
+
+    class capture_stderr(list):
+        """
+        Replace sys.stderr with a temporary StringIO
+
+        Used as a context manager; on exit the captured text is appended to
+        this list as a single string and the previous sys.stderr is restored.
+        """
+        def __enter__(self):
+            # Remember the stream being replaced so __exit__ can put it back.
+            self.sys_stderr = sys.stderr
+            self.stringio = StringIO()
+            sys.stderr = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stderr = self.sys_stderr
+
+    def setHooks(self):
+        """Install this test case's emitModuleHook / emitFunctionHook as the JIT emit hooks."""
+        torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)
+
+    def clearHooks(self):
+        """Remove both JIT emit hooks by setting them to None."""
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def setUp(self):
+        """Run the base setUp, restore the tracer-warning filter once, and install emit hooks."""
+        super().setUp()
+        # unittest overrides all warning filters and forces all of them to show up
+        # after we install our own to silence those coming from inside PyTorch.
+        # This will ensure that our filter still takes precedence.
+        # The class-level flag makes this run once per process, not per test.
+        if not JitTestCase._restored_warnings:
+            torch.jit.TracerWarning.ignore_lib_warnings()
+            JitTestCase._restored_warnings = True
+        self.setHooks()
+
+    def tearDown(self):
+        """Run the base tearDown, then remove emit hooks and reset the class registry."""
+        super().tearDown()
+        # needs to be cleared because python might be unloaded before
+        # the callback gets destructed
+        self.clearHooks()
+        clear_class_registry()
+
+    def assertAllFused(self, graph, except_for=()):
+        """Assert that ``graph`` holds exactly one fusion group and little else.
+
+        Exactly one block (found via the fast-path walk below) must contain
+        FUSION_GROUP nodes, it must contain exactly one of them, and every
+        node in that block must have a kind in the allowed set or in
+        ``except_for``.
+        """
+
+        # note this helper collects nodes on 'fast path' only
+        # i.e. the true blocks of specialized checks
+        def get_nodes_and_parents_recursively(block, kind, acc):
+            for node in block.nodes():
+                if node.kind() == kind:
+                    acc[block].append(node)
+                elif node.kind() == 'prim::DifferentiableGraph':
+                    get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
+                # For an If whose first input comes from a guard node, only the
+                # first block is descended into.
+                elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or
+                                                    node.inputs().__next__().node().kind() == 'prim::TypeCheck' or
+                                                    node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'):
+                    get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc)
+                else:
+                    for inner_block in node.blocks():
+                        get_nodes_and_parents_recursively(inner_block, kind, acc)
+
+        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
+                         'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for)
+
+        # Maps each block to the fusion-group nodes found directly in it.
+        fusion_groups : dict[torch._C.Block, list[torch._C.Node]] = defaultdict(list)
+        get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
+        self.assertTrue(len(fusion_groups) == 1, f'got {graph}')
+        # From here on `graph` names the block that holds the fusion group.
+        (graph, fusion_nodes) = next(iter(fusion_groups.items()))
+        # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes`
+        self.assertTrue(len(fusion_nodes) == 1, f'got {graph}')
+        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
+                        f'got {graph}')
+
+    def _isHookExceptionOk(self, e):
+        se = str(e)
+        allowed = ("Could not export Python function",
+                   "closures are not exportable")
+        for a in allowed:
+            if a in se:
+                return True
+        return False
+
+    def _compared_saved_loaded(self, m):
+        """Check that saving ``m``, loading it and saving again yields the same code.
+
+        ``m`` is saved with ``torch.jit.save``; the saved copy is loaded and
+        saved a second time, and the ``.py`` files under ``archive/code/`` of
+        both archives are compared pairwise. For a ``torch._C.ScriptModule``
+        the ivalue tags of the original and the loaded module must also match.
+
+        Returns early (without checking) when ``m`` has empty code or no
+        methods, or when saving raises a ``RuntimeError`` accepted by
+        ``_isHookExceptionOk``.
+        """
+        def extract_files(buffer):
+            # crack open the zip format to get at the main module code
+            with zipfile.ZipFile(buffer) as archive:
+                # check that we have no duplicate names
+                self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
+                files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
+                # unwrap all the code files into strings
+                code_files_str = filter(lambda x: x.endswith('.py'), files)
+                code_files = []
+                for f in code_files_str:
+                    with archive.open(f) as stream:
+                        code_files.append("".join([line.decode() for line in stream]))
+
+                # unpickle all the debug files
+                debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
+                debug_files = []
+                for f in debug_files_str:
+                    with archive.open(f) as stream:
+                        debug_files.append(pickle.load(stream))
+                return code_files, debug_files
+
+        # disable the hook while we parse code, otherwise we will re-enter the hook
+        with torch._jit_internal._disable_emit_hooks():
+            try:
+                # short-circuit if this is an empty function or module
+                if len(m.code) == 0:
+                    return
+                if isinstance(m, torch._C.ScriptModule):
+                    if len(m._method_names()) == 0:
+                        return
+
+                # save the module to a buffer
+                buffer = io.BytesIO()
+                torch.jit.save(m, buffer)
+                # copy the data in the buffer so we can restore it later. This
+                # is because py2 and py3 have different semantics with zipfile
+                # and it's easier to just work with a fresh copy each time.
+                buffer_copy = buffer.getvalue()
+
+                code_files, _debug_files = extract_files(buffer)
+
+            except RuntimeError as e:
+                # Only the known "not exportable" failures are tolerated.
+                if not self._isHookExceptionOk(e):
+                    raise
+                else:
+                    return
+
+            # import the model again (from the copy we made of the original)
+            buffer2 = io.BytesIO(buffer_copy)
+            imported = torch.jit.load(buffer2)
+
+            # save it again
+            saved_module_buffer_2 = io.BytesIO()
+            torch.jit.save(imported, saved_module_buffer_2)
+
+            saved_module_buffer_2.seek(0)
+            code_files_2, _debug_files_2 = extract_files(saved_module_buffer_2)
+
+            # strict=True: a differing number of code files is itself an error.
+            for a, b in zip(code_files, code_files_2, strict=True):
+                self.assertMultiLineEqual(a, b)
+
+            if isinstance(m, torch._C.ScriptModule):
+                self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))
+
+
+    def emitFunctionHook(self, func):
+        """Emit hook for functions: run the save/load round-trip check on ``func``."""
+        # func has invalid names for export, skip the jitter check
+        if func.name == "" or "aten::" in func.name:
+            return
+        self._compared_saved_loaded(func)
+
+    def emitModuleHook(self, module):
+        """Emit hook for modules: run the save/load round-trip check on ``module``."""
+        self._compared_saved_loaded(module)
+
+
+    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
+        """Return a save/load round-tripped copy of ``m``, packing around the save.
+
+        Submodules exposing ``_pack`` are packed before ``m`` is saved to an
+        in-memory buffer and those exposing ``_unpack`` are unpacked
+        afterwards; the loaded copy is unpacked the same way. With
+        ``also_test_file`` the loaded copy is additionally saved to a
+        temporary file and reloaded, and that second copy is returned.
+        ``map_location`` is forwarded to every ``torch.jit.load`` call.
+        """
+        buffer = io.BytesIO()
+        m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None)
+        torch.jit.save(m, buffer)
+        m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer, map_location=map_location)
+        imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+
+        if not also_test_file:
+            return imported
+
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            try:
+                f.close()
+                imported.save(f.name)
+                result = torch.jit.load(f.name, map_location=map_location)
+            finally:
+                os.unlink(f.name)
+
+        result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+        return result
+
+    def assertGraphContains(self, graph, kind, consider_subgraphs=False):
+
+        if consider_subgraphs:
+            strgraph = str(graph)
+            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
+            self.assertTrue(count > 0)
+            return
+
+        def nodes(block):
+            out = []
+            for node in block.nodes():
+                if node.kind() == kind:
+                    out.append(node)
+                for block in node.blocks():
+                    out += nodes(block)
+            return out
+
+        out_nodes = nodes(graph)
+        self.assertTrue(len(out_nodes) > 0)
+
+    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
+        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
+            if actual == expected:
+                return
+            subgraph = 'including' if consider_subgraphs else 'excluding'
+            raise AssertionError(
+                f'{graph}\nError: graph contains {actual} {kind} nodes ({subgraph} subgraphs) but expected {expected}')
+
+        if consider_subgraphs:
+            strgraph = str(graph)
+            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
+            perform_assert(graph, kind, count, num_kind_nodes,
+                           consider_subgraphs)
+            return
+
+        def nodes(block):
+            out = []
+            for node in block.nodes():
+                if node.kind() == kind:
+                    out.append(node)
+                for block in node.blocks():
+                    out += nodes(block)
+            return out
+
+        out_nodes = nodes(graph)
+        perform_assert(graph, kind, len(out_nodes), num_kind_nodes,
+                       consider_subgraphs)
+
+    def assertExpectedONNXGraph(self, g, *args, **kwargs):
+        """Run ``torch.onnx._optimize_trace`` (ONNX export type) on ``g``, then ``assertExpectedGraph``."""
+        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
+        self.assertExpectedGraph(g, *args, **kwargs)
+
+    def assertExpectedGraph(self, trace, *args, **kwargs):
+        """Compare the canonicalized graph text of ``trace`` with the expected output.
+
+        ``trace`` is either a ``torch._C.Graph`` or an object whose ``graph()``
+        returns one. The graph is linted, dead-code eliminated and
+        canonicalized (with a lint after each step) before its string form is
+        handed to ``assertExpected`` together with ``args``/``kwargs``.
+        """
+        if isinstance(trace, torch._C.Graph):
+            graph = trace
+        else:
+            graph = trace.graph()
+
+        torch._C._jit_pass_lint(graph)
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+        graph = torch._C._jit_pass_canonicalize(graph)
+        torch._C._jit_pass_lint(graph)
+        self.assertExpected(str(graph), *args, **kwargs)
+
+    def run_pass(self, name, trace):
+        if isinstance(trace, torch._C.Graph):
+            graph = trace
+            set_graph = False
+        else:
+            set_graph = True
+            graph = trace.graph()
+
+        torch._C._jit_pass_lint(graph)
+        result = getattr(torch._C, '_jit_pass_' + name)(graph)
+        if result is not None and not isinstance(result, bool):
+            graph = result
+        torch._C._jit_pass_lint(graph)
+
+        if set_graph:
+            trace.set_graph(graph)
+        return graph
+
+    def get_frame_vars(self, frames_up):
+        """Return the variables visible in a caller's frame as one dict.
+
+        Starting from this method's own frame, ``frames_up + 1`` ``f_back``
+        links are followed. The result holds that frame's locals updated with
+        its globals, so a global wins when a name exists in both.
+
+        Raises:
+            RuntimeError: if the current frame cannot be inspected or the
+                stack is shallower than requested.
+        """
+        frame = inspect.currentframe()
+        if not frame:
+            raise RuntimeError("failed to inspect frame")
+        i = 0
+        while i < frames_up + 1:
+            frame = frame.f_back
+            if not frame:
+                raise RuntimeError("failed to get frame")
+            i += 1
+        defined_vars: dict[str, Any] = {}
+        defined_vars.update(frame.f_locals)
+        defined_vars.update(frame.f_globals)
+        return defined_vars
+
+    def assertRaisesRegexWithHighlight(self, exception, regex, highlight):
+        """Return a context manager checking the raised error's type, regex and source highlight."""
+        return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)
+
+    def checkScriptRaisesRegex(self, script, inputs, exception, regex,
+                               name=None, outputs=None, capture_output=False,
+                               frames_up=1, profiling=ProfilingMode.PROFILING):
+        """
+        Checks that a given function will throw the correct exception,
+        when executed with normal python, the string frontend, and the
+        AST frontend. Logic taken from `checkScript` (see comments there
+        for details)
+
+        `script` is either source text (then `name` selects the function
+        defined by it) or a Python callable. The AST frontend is only
+        exercised for callables. `outputs`, `capture_output` and `profiling`
+        are accepted but not read by this method.
+        """
+        with enable_profiling_mode_for_profiling_tests():
+            # Normal Python
+            with self.assertRaisesRegex(exception, regex):
+                if isinstance(script, str):
+                    # Execute the source against the caller's variables so
+                    # that names it references resolve.
+                    frame = self.get_frame_vars(frames_up)
+                    the_locals: dict[str, Any] = {}
+                    execWrapper(script, glob=frame, loc=the_locals)
+                    frame.update(the_locals)
+
+                    python_fn = frame[name]
+                else:
+                    python_fn = script
+
+                python_fn(*inputs)
+
+            # String frontend
+            with self.assertRaisesRegex(exception, regex):
+                if isinstance(script, str):
+                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+                    string_frontend = getattr(cu, name)
+                else:
+                    source = textwrap.dedent(inspect.getsource(script))
+                    cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
+                    string_frontend = getattr(cu, script.__name__)
+
+                string_frontend(*inputs)
+
+            # Python AST frontend
+            if not isinstance(script, str):
+                with self.assertRaisesRegex(exception, regex):
+                    ge = torch.jit.script(python_fn)
+                    ge(*inputs)
+
+    def checkBailouts(self, model, inputs, expected):
+        state = model.get_debug_state()
+        plan = get_execution_plan(state)
+        num_bailouts = plan.code.num_bailouts()
+        for i in range(num_bailouts):
+            plan.code.request_bailout(i)
+            bailout_outputs = model(*inputs)
+            self.assertEqual(bailout_outputs, expected)
+
+    def checkScript(self,
+                    script,
+                    inputs,
+                    name='func',
+                    optimize=True,
+                    inputs_requires_grad=False,
+                    capture_output=False,
+                    frames_up=1,
+                    profiling=ProfilingMode.PROFILING,
+                    atol=None,
+                    rtol=None):
+        """
+        Checks that a given script generates the same output as the Python
+        version using the given inputs.
+
+        `script` is either source text (then `name` selects the function it
+        defines) or a Python callable; a callable is first checked through
+        the string frontend by a recursive call on its source, then through
+        `torch.jit.script`. With `inputs_requires_grad` the scripted function
+        runs on detached copies of the tensor inputs that require grad. With
+        `capture_output` stdout is captured during the runs and, except on
+        Windows, compared through `assertExpected`. `atol`/`rtol` are
+        forwarded to the output comparisons.
+
+        Returns the scripted function.
+        """
+        with torch.jit.optimized_execution(optimize), enable_profiling_mode_for_profiling_tests():
+            # An extra profiling run is done when any input tensor requires grad.
+            extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs)
+            if isinstance(script, str):
+                # Compile the string to a Script function
+                # with enable_profiling_mode():
+                cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+
+                # Execute the Python function so we can run it later and get its
+                # outputs
+
+                frame = self.get_frame_vars(frames_up)
+                the_locals: dict[str, Any] = {}
+                execWrapper(script, glob=frame, loc=the_locals)
+                frame.update(the_locals)
+
+                python_fn = frame[name]
+                scripted_fn = getattr(cu, name)
+            else:
+
+                # Check the string frontend first
+                source = textwrap.dedent(inspect.getsource(script))
+                # frames_up=2 accounts for the extra checkScript frame added
+                # by this recursive call. atol/rtol are not forwarded here.
+                self.checkScript(
+                    source,
+                    inputs,
+                    script.__name__,
+                    optimize=optimize,
+                    inputs_requires_grad=inputs_requires_grad,
+                    capture_output=capture_output,
+                    profiling=profiling,
+                    frames_up=2)
+
+                # Continue checking the Python frontend
+                scripted_fn = torch.jit.script(script, _frames_up=1)
+                python_fn = script
+
+            if inputs_requires_grad:
+                recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
+            else:
+                recording_inputs = inputs
+
+            if capture_output:
+                with self.capture_stdout() as script_stdout:
+                    script_outputs = scripted_fn(*recording_inputs)
+                with self.capture_stdout():
+                    opt_script_outputs = scripted_fn(*recording_inputs)
+                with self.capture_stdout():
+                    python_outputs = python_fn(*inputs)
+                if not IS_WINDOWS:
+                    self.assertExpected(script_stdout[0], subname='stdout')
+                self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol)
+            else:
+                # profiling run
+                script_outputs = scripted_fn(*recording_inputs)
+                if inputs_requires_grad or extra_profile_runs:
+                    opt_script_outputs = scripted_fn(*recording_inputs)
+                # optimized run
+                opt_script_outputs = scripted_fn(*recording_inputs)
+                if TEST_BAILOUTS:
+                    self.checkBailouts(scripted_fn, inputs, opt_script_outputs)
+                python_outputs = python_fn(*inputs)
+            self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol)
+            self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol)
+            return scripted_fn
+
+    def checkTrace(self, func, reference_tensors, input_tensors=None,
+                   drop=None, allow_unused=False, verbose=False,
+                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
+                   _force_outplace=False, grad_atol=None, grad_rtol=None):
+
+        # TODO: check gradients for parameters, not just inputs
+        def allSum(vs):
+            # drop allows us to remove some values from ever being used
+            # to test unused outputs
+            if drop is not None:
+                vs = vs[:-drop]
+            # we don't want all the grad for all the outputs to be the same
+            # so we multiply each by a constant
+            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
+        if input_tensors is None:
+            input_tensors = reference_tensors
+
+        def flatten_inputs(inputs):  # flat tuple of all tensors nested in `inputs` itself
+            def input_reduce(input, fn, acc):
+                if isinstance(input, torch.Tensor):
+                    fn(input, acc)
+                elif isinstance(input, dict):
+                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+                else:
+                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+                return acc  # the same list object that `fn` appended to
+            return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))
+
+        nograd_inputs = reference_tensors
+        if inputs_require_grads:
+            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
+        else:
+            recording_inputs = reference_tensors
+
+        # `check_trace` is set to False because check_trace is run with @no_grad
+        # Also, `checkTrace` already does all the checks
+        # against python function
+        ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance,
+                             _force_outplace=_force_outplace, check_trace=False)
+
+        if export_import:
+            ge = self.getExportImportCopy(ge)
+
+        if verbose:
+            print(ge.graph)
+
+        # test no gradients case
+        outputs = func(*nograd_inputs)
+        outputs_ge = ge(*nograd_inputs)
+        self.assertEqual(outputs, outputs_ge)
+
+        # test gradients case
+        outputs = func(*recording_inputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
+                                        allow_unused=allow_unused)
+
+        outputs_ge = ge(*recording_inputs)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
+                                           allow_unused=allow_unused)
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
+
+        # test the grad grad case
+        outputs = func(*recording_inputs)
+        l1 = allSum(outputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
+                                        allow_unused=allow_unused)
+        if inputs_require_grads:
+            l2 = (allSum(grads) * l1)
+            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
+
+        outputs_ge = ge(*recording_inputs)
+        l1_ge = allSum(outputs_ge)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(
+                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            l2_ge = (allSum(grads_ge) * l1_ge)
+            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
+
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
+            for g2, g2_ge in zip(grads2, grads2_ge, strict=True):
+                if g2 is None and g2_ge is None:
+                    continue
+                self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4)
+
+        return ge
+
+    def checkModule(self, nn_module, args):
+        """
+        Script ``nn_module``, assert its output on ``args`` equals the eager output
+        (each run inside ``freeze_rng_state()``), then call ``assertExportImportModule``.
+        """
+        scripted = torch.jit.script(nn_module)
+
+        with freeze_rng_state():
+            expected = nn_module(*args)
+
+        with freeze_rng_state():
+            actual = scripted(*args)
+
+        self.assertEqual(expected, actual)
+        self.assertExportImportModule(scripted, args)
+
+        return scripted
+
+class NoTracerWarnContextManager:  # context manager: sets the tracer-state-warn flag to False
+    def __enter__(self):
+        self.prev = torch._C._jit_get_tracer_state_warn()  # remember the current flag
+        torch._C._jit_set_tracer_state_warn(False)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_tracer_state_warn(self.prev)  # write the remembered flag back
+
+@contextmanager
+def inline_everything_mode(should_inline):  # temporarily set the JIT inline-everything flag
+    previous = torch._C._jit_get_inline_everything_mode()
+    torch._C._jit_set_inline_everything_mode(should_inline)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_inline_everything_mode(previous)  # runs even if the body raises
+
+@contextmanager
+def set_fusion_group_inlining(inlining):  # temporarily set the fusion-group-inlining flag
+    previous = torch._C._debug_get_fusion_group_inlining()
+    torch._C._debug_set_fusion_group_inlining(inlining)
+    try:
+        yield
+    finally:
+        torch._C._debug_set_fusion_group_inlining(previous)  # runs even if the body raises
+
+# note: not re-entrant, use unnested only
+@contextmanager
+def disable_autodiff_subgraph_inlining(enabled=True):  # enabled=True passes False to the setter
+    torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
+    try:
+        yield
+    finally:
+        torch._C._debug_set_autodiff_subgraph_inlining(True)  # always True, not the prior value
+
+def _inline_everything(fn):  # decorator: call `fn` inside inline_everything_mode(True)
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with inline_everything_mode(True):
+            return fn(*args, **kwargs)  # hand the wrapped function's result to the caller
+    return wrapper
+
+# this exists for forward compatibility reasons temporarily.
+# TODO(suo) remove
+def _tmp_donotuse_dont_inline_everything(fn):  # decorator: call `fn` inside inline_everything_mode(False)
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with inline_everything_mode(False):
+            return fn(*args, **kwargs)  # hand the wrapped function's result to the caller
+    return wrapper
+
+# make it easy to quickly define/trace a function for these tests
+def _trace(*args, **kwargs):  # decorator factory around torch.jit.trace
+    def wrapper(func):  # `args` become the example inputs, `kwargs` are forwarded unchanged
+        return torch.jit.trace(func, args, **kwargs)
+    return wrapper
+
+
+def enable_cpu_fuser(fn):  # decorator: call `fn` with the CPU fuser overrides switched on
+    def wrapper(*args, **kwargs):
+        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_set_te_must_use_llvm_cpu(False)
+        try:
+            return fn(*args, **kwargs)  # propagate the result; `finally` still runs
+        finally:
+            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
+            torch._C._jit_override_can_fuse_on_cpu(False)
+            torch._C._jit_set_te_must_use_llvm_cpu(True)
+    return functools.wraps(fn)(wrapper)  # keep fn's __name__ and docstring on the wrapper
+
+
+def enable_cpu_fuser_if(cond):
+    # Hand back the real CPU-fuser decorator when `cond` is truthy, else a pass-through one.
+    if cond:
+        return enable_cpu_fuser
+    def noop_fuser(fn):
+        def wrapper(*args, **kwargs):
+            return fn(*args, **kwargs)
+        return wrapper
+    return noop_fuser
+
+def get_forward(c):  # the object's 'forward' method, looked up through `_get_method`
+    return c._get_method('forward')
+
+def get_forward_graph(c):  # the `.graph` attribute of that 'forward' method
+    return c._get_method('forward').graph
+
+def get_module_method(m, module, method):  # method `method` of attribute `module` on `m._c`
+    return m._c.getattr(module)._get_method(method)
+
+def attrs_with_prefix(module, prefix):  # names in `module._modules._c` starting with `prefix`
+    return [x for x, _ in module._modules._c.items()
+            if x.startswith(prefix)]
+
+def warmup_backward(f, *args):  # repeat the backward computation of `f` three times
+    num_runs = 3
+    grads_per_run = []
+    for _ in range(num_runs):
+        if args:
+            grads = torch.autograd.grad(f, *args)  # extra args go straight to autograd.grad
+            grads_per_run.append(grads)
+        else:
+            f.backward(retain_graph=True)  # nothing is collected on this path
+
+    return grads_per_run
+
+# TODO: Remove me once https://bugs.python.org/issue42666 is resolved
+def make_global(*args):  # expose each object as an attribute of the module that defines it
+    for arg in args:
+        setattr(sys.modules[arg.__module__], arg.__name__, arg)  # attribute name is arg.__name__
+
+# Helper function to eval Python3 code without causing a syntax error for
+# this file under py2
+def _get_py3_code(code, fn_name):
+    # Write `code` to a throwaway script.py, import it as a module, return its `fn_name` attribute.
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        script_path = os.path.join(tmp_dir, 'script.py')
+        with open(script_path, 'w') as script_file:
+            script_file.write(code)
+        spec = importlib.util.spec_from_file_location(fn_name, script_path)
+        module = importlib.util.module_from_spec(spec)
+        loader = spec.loader
+        assert isinstance(loader, Loader)  # Assert type to meet MyPy requirement
+        loader.exec_module(module)
+        return getattr(module, fn_name)
+
+class TensorExprTestOptions:  # saves several torch._C JIT/fuser settings, overrides them; restore() writes them back
+    def __init__(self) -> None:
+        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)  # presumably returns the prior value - TODO confirm
+        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)  # same assumption as the line above
+
+        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
+        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
+        torch._C._jit_set_texpr_fuser_enabled(True)
+        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
+        torch._C._debug_set_fusion_group_inlining(False)
+        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
+        torch._C._jit_set_te_must_use_llvm_cpu(False)
+
+    def restore(self):  # pass every value saved in __init__ back to its setter
+        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
+        torch._C._get_graph_executor_optimize(self.old_profiling_mode)
+
+        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
+        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
+        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
+        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
+        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
+
+def clone_inputs(args):  # detached clones of tensors / tensor iterables; other values pass through
+    cloned: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+    for value in args:
+        if isinstance(value, torch.Tensor):
+            copy = value.detach().clone()
+        elif is_iterable_of_tensors(value):
+            copy = [t.detach().clone() for t in value]
+        else:
+            copy = value
+        cloned.append(copy)
+
+    return cloned
+
+def get_traced_sample_variant_pairs(device, dtype, op):  # pair each sample input of `op` with each usable variant
+    # tuples of (variant, sample)
+    outputs: list[tuple[Any, Any]] = []
+
+    samples = op.sample_inputs(device, dtype)
+
+    # Acquires variants to test
+    func = op.get_op()
+    method = op.get_method()
+    variants = {
+        # TODO: inplace tests currently fail, fix and add inplace variant
+        'function': func, 'method': method,
+    }
+
+    # TODO: find better way to standardize on op registration itself..
+    has_fake_function = op.name in ["resize_", 'resize_as_']
+
+    if has_fake_function:
+        variants = {'method': getattr(torch.Tensor, op.name)}
+
+    # In eager mode, these ops can take (Tensor, bool) args; but in
+    # JIT they can only take (Tensor, Scalar), and bool is not a
+    # scalar in the JIT type system. So to test these in JIT, the bool
+    # is converted to an int for the test.
+    ops_with_unsupported_bool_args = [
+        {
+            "name": "div_floor_rounding",
+            "arg_idx": [0],
+        },
+        {
+            "name": "div_no_rounding_mode",
+            "arg_idx": [0],
+        },
+        {
+            "name": "div_trunc_rounding",
+            "arg_idx": [0],
+        },
+        {
+            "name": "index_fill",
+            "arg_idx": [2],
+        },
+        {
+            "name": "full_like",
+            "arg_idx": [0],
+        },
+        {
+            "name": "mul",
+            "arg_idx": [0],
+        },
+        {
+            "name": "new_full",
+            "arg_idx": [1],
+        },
+    ]
+
+    # doesn't support tracing
+    if has_fake_function:
+        return outputs  # still the empty list at this point
+
+    for sample in samples:
+        for variant in variants.values():
+            if variant is None:
+                continue
+
+            if is_lambda(variant):
+                continue
+
+            matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args)
+            for op_data in matching_ops:
+                for idx in op_data["arg_idx"]:
+                    args = list(sample.args)
+                    if len(sample.args) > idx and isinstance(sample.args[idx], bool):
+                        args[idx] = int(args[idx])
+                    sample.args = tuple(args)  # rewrites the sample object in place
+
+            outputs.append((variant, sample))
+
+    return outputs
+
+# types.LambdaType gave false positives
+def is_lambda(lamb):  # True only for function objects whose __name__ is the lambda name
+    function_type, lambda_name = type(lambda: 0), (lambda: 0).__name__
+    return isinstance(lamb, function_type) and lamb.__name__ == lambda_name
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e1ecf8f4f707c9b3712a6fb738fc9ce1467b835
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py
@@ -0,0 +1,243 @@
+# mypy: ignore-errors
+
+import torch._dynamo.test_case
+import unittest.mock
+import os
+import contextlib
+import torch._logging
+import torch._logging._internal
+from contextlib import AbstractContextManager
+from collections.abc import Callable
+from torch._dynamo.utils import LazyString
+from torch._inductor import config as inductor_config
+import logging
+import io
+
+@contextlib.contextmanager
+def preserve_log_state():  # run the body against a fresh LogState, then reinstate the saved one
+    saved_state = torch._logging._internal._get_log_state()
+    torch._logging._internal._set_log_state(torch._logging._internal.LogState())
+    try:
+        yield
+    finally:
+        torch._logging._internal._set_log_state(saved_state)
+        torch._logging._internal._init_logs()  # re-initialise logging after the state swap
+
+def log_settings(settings):
+    # Returns an ExitStack; closing it undoes the TORCH_LOGS patch and reinstates the log state.
+    exit_stack = contextlib.ExitStack()
+    exit_stack.enter_context(preserve_log_state())
+    exit_stack.enter_context(unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings}))
+    torch._logging._internal._init_logs()
+    return exit_stack
+
+def log_api(**kwargs):  # configure logging through torch._logging.set_logs; returns an ExitStack
+    exit_stack = contextlib.ExitStack()
+    exit_stack.enter_context(preserve_log_state())  # closing the stack reinstates the prior log state
+    torch._logging.set_logs(**kwargs)
+    return exit_stack
+
+
+def kwargs_to_settings(**kwargs):
+    # Translate set_logs-style kwargs into the equivalent comma-separated TORCH_LOGS string.
+    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}
+    settings = []
+
+    def append_setting(name, level):
+        if not (isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY):
+            raise ValueError("Invalid value for setting")
+        settings.append(INT_TO_VERBOSITY[level] + name)
+
+    for name, val in kwargs.items():
+        if isinstance(val, bool):
+            # Listing a name enables it, so a False flag must be left out entirely.
+            if val:
+                settings.append(name)
+        elif isinstance(val, int):
+            append_setting(name, val)
+        elif isinstance(val, dict) and name == "modules":
+            for module_qname, level in val.items():
+                append_setting(module_qname, level)
+        else:
+            raise ValueError("Invalid value for setting")
+
+    return ",".join(settings)
+
+
+# Note on testing strategy:
+# This decorator factory does three things:
+# 1. Runs two versions of a test:
+#    1a. patches the env var log settings to some specific value
+#    1b. calls torch._logging.set_logs(..)
+# 2. patches the emit method of each setup handler to gather records
+# that are emitted to each console stream
+# 3. passes a ref to the gathered records to each test case for checking
+#
+# The goal of this testing in general is to ensure that given some settings env var
+# that the logs are setup correctly and capturing the correct records.
+def make_logging_test(**kwargs):  # decorator factory: the test body runs once per configuration path
+    def wrapper(fn):
+        @inductor_config.patch({"fx_graph_cache": False})
+        def test_fn(self):
+            # `records` is filled by self._handler_watcher and handed to `fn` for checking
+            torch._dynamo.reset()
+            records = []
+            # run with env var
+            if len(kwargs) == 0:
+                with self._handler_watcher(records):
+                    fn(self, records)
+            else:
+                with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
+                    fn(self, records)
+
+            # run with API
+            torch._dynamo.reset()
+            records.clear()
+            with log_api(**kwargs), self._handler_watcher(records):
+                fn(self, records)
+
+
+        return test_fn
+
+    return wrapper
+
+def make_settings_test(settings):  # decorator factory: run the test once under TORCH_LOGS=settings
+    def wrapper(fn):
+        def test_fn(self):
+            torch._dynamo.reset()
+            records = []  # filled by self._handler_watcher, then checked by `fn`
+            # run with env var
+            with log_settings(settings), self._handler_watcher(records):
+                fn(self, records)
+
+        return test_fn
+
+    return wrapper
+
+class LoggingTestCase(torch._dynamo.test_case.TestCase):  # base class with log-record capture helpers
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls._exit_stack.enter_context(
+            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
+        )
+        cls._exit_stack.enter_context(
+            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
+        )
+        cls._exit_stack.enter_context(
+            unittest.mock.patch("torch._dynamo.config.verbose", False)
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._exit_stack.close()  # undoes the patches entered in setUpClass
+        torch._logging._internal.log_state.clear()
+        torch._logging._init_logs()
+
+    def hasRecord(self, records, m):  # True if any record's formatted message contains `m`
+        return any(m in r.getMessage() for r in records)
+
+    def getRecord(self, records, m):  # the unique record containing `m`; fails on zero or several
+        record = None
+        for r in records:
+            # NB: not r.msg because it looks like 3.11 changed how they
+            # structure log records
+            if m in r.getMessage():
+                self.assertIsNone(
+                    record,
+                    msg=LazyString(
+                        lambda: f"multiple matching records: {record} and {r} among {records}"
+                    ),
+                )
+                record = r
+        if record is None:
+            self.fail(f"did not find record with {m} among {records}")
+        return record
+
+    # This patches the emit method of each handler to gather records
+    # as they are emitted
+    def _handler_watcher(self, record_list):
+        exit_stack = contextlib.ExitStack()
+
+        def emit_post_hook(record):
+            nonlocal record_list
+            record_list.append(record)
+
+        # registered logs are the only ones with handlers, so patch those
+        for log_qname in torch._logging._internal.log_registry.get_log_qnames():
+            logger = logging.getLogger(log_qname)
+            num_handlers = len(logger.handlers)
+            self.assertLessEqual(
+                num_handlers,
+                2,
+                "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
+            )
+
+            self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
+
+            for handler in logger.handlers:
+                old_emit = handler.emit
+                # Bind this handler's emit as a default; a bare closure would see the last loop value.
+                def new_emit(record, old_emit=old_emit):
+                    old_emit(record)
+                    emit_post_hook(record)
+
+                exit_stack.enter_context(
+                    unittest.mock.patch.object(handler, "emit", new_emit)
+                )
+
+        return exit_stack  # closing it restores every handler's original emit
+
+
+def logs_to_string(module, log_option):
+    """Return ``(log_stream, ctx_manager)`` for capturing one artifact log, e.g.
+    logs_to_string("torch._inductor.compile_fx", "post_grad_graphs")
+    collects the output of TORCH_LOGS="post_grad_graphs" from the
+    torch._inductor.compile_fx module into ``log_stream`` while the context
+    returned by calling ``ctx_manager()`` is active.
+    """
+    log_stream = io.StringIO()
+    handler = logging.StreamHandler(stream=log_stream)
+
+    @contextlib.contextmanager
+    def tmp_redirect_logs():
+        logger = torch._logging.getArtifactLogger(module, log_option)  # bound before `try` so `finally` can always see it
+        try:
+            logger.addHandler(handler)
+            yield
+        finally:
+            logger.removeHandler(handler)
+
+    def ctx_manager():
+        exit_stack = log_settings(log_option)
+        exit_stack.enter_context(tmp_redirect_logs())
+        return exit_stack
+
+    return log_stream, ctx_manager
+
+
+def multiple_logs_to_string(module: str, *log_options: str) -> tuple[list[io.StringIO], Callable[[], AbstractContextManager[None]]]:
+    """Example:
+    multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs")
+    returns the output of TORCH_LOGS="pre_grad_graphs, post_grad_graphs" from the
+    torch._inductor.compile_fx module.
+    """
+    log_streams = [io.StringIO() for _ in range(len(log_options))]  # one stream per option, same order
+    handlers = [logging.StreamHandler(stream=log_stream) for log_stream in log_streams]
+
+    @contextlib.contextmanager
+    def tmp_redirect_logs():
+        loggers = [torch._logging.getArtifactLogger(module, option) for option in log_options]
+        try:
+            for logger, handler in zip(loggers, handlers, strict=True):
+                logger.addHandler(handler)
+            yield
+        finally:
+            for logger, handler in zip(loggers, handlers, strict=True):
+                logger.removeHandler(handler)
+
+    def ctx_manager() -> AbstractContextManager[None]:
+        exit_stack = log_settings(", ".join(log_options))
+        exit_stack.enter_context(tmp_redirect_logs())
+        return exit_stack  # type: ignore[return-value]
+
+    return log_streams, ctx_manager
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..828585f5d3653e5d4b635259c6b241ff2d4f8575
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd8f57492d46b4645404982fb9ee64c5c817c489
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85db6ebeb321b731b55e12af455456f7622d6913
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26d3f402e741a54f21a5fca48beded5b0a58aec
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py
@@ -0,0 +1,26 @@
+# mypy: ignore-errors
+
+from torch.testing._internal.opinfo.core import OpInfo
+from torch.testing._internal.opinfo.definitions import (
+    _masked,
+    fft,
+    linalg,
+    signal,
+    special,
+)
+
+
+# Operator database: the op_db lists of the definition submodules, concatenated
+op_db: list[OpInfo] = [
+    *fft.op_db,
+    *linalg.op_db,
+    *signal.op_db,
+    *special.op_db,
+    *_masked.op_db,
+]
+
+python_ref_db: list[OpInfo] = [  # only fft, linalg and special contribute python_ref_db entries
+    *fft.python_ref_db,
+    *linalg.python_ref_db,
+    *special.python_ref_db,
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f723c187776a217ac87a861a9a3964233560dfa9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfea63f3e0b7228bcdfc1d06d5a4fd2e31391990
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63215599999578848dd24090e2a5539c6be49f55
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6617fefe7681411eb473ab668903870c41b95fad
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14e0b72c9b49ddcf45094fde63a5352c14688ff9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c24486390864a023bc27c2d835e0b5c7ba633c24
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a994e5a184f890c407497320d93cdcf70f6bb89
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cca093e0dd425ac97e10fed0e38da8dd1e718465
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65fbef658a4545ae9459fc5ad561572865d96f3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -0,0 +1,1212 @@
+# mypy: ignore-errors
+
+import unittest
+from collections.abc import Sequence
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import tol, toleranceOverride
+from torch.testing._internal.common_dtype import (
+    all_types_and,
+    all_types_and_complex_and,
+    complex_types,
+    floating_and_complex_types_and,
+    floating_types_and,
+    integral_types,
+)
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    gradcheck_wrapper_masked_operation,
+    gradcheck_wrapper_masked_pointwise_operation,
+    M,
+    OpInfo,
+    ReductionOpInfo,
+    S,
+    sample_inputs_reduction,
+    SampleInput,
+)
+from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy
+
+
+# Shared by the log_softmax, softmax and softmin sample generators
+def sample_inputs_softmax_variant(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    with_dtype=False,
+    use_zero_dimensions=True,
+    **kwargs,
+):
+    """Return a lazy iterable of ``SampleInput(tensor, args=dims)`` samples.
+
+    With ``with_dtype`` every sample also carries ``dtype=torch.float64`` as a
+    keyword argument.  The incoming ``**kwargs`` are accepted but not used.
+    """
+    shapes_and_dims = [((S,), (0,)), ((S, S), (0,)), ((S, S), (1,))]
+    shapes_and_dims += [((S, S), (-1,)), ((S, M, S), (2,))]
+    if use_zero_dimensions:
+        shapes_and_dims.append(((S, 0, 0), (-1,)))
+    # PyTorch on XLA throws an error when passed with dim argument for 0d tensor.
+    # See https://github.com/pytorch/xla/issues/3061 for more details.
+    if torch.device(device).type != "xla":
+        shapes_and_dims.append(((), (0,)))
+    sample_kwargs = {"dtype": torch.float64} if with_dtype else None
+    new_tensor = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    return (
+        SampleInput(new_tensor(shp), args=dims, kwargs=sample_kwargs)
+        for shp, dims in shapes_and_dims
+    )
+
+
+def _generate_masked_op_mask(input_shape, device, **kwargs):
+    """Yield candidate ``mask`` values for an input of shape ``input_shape``.
+
+    ``None`` (no mask) comes first, then a bool tensor of the full input shape.
+    For inputs with more than two dimensions, masks that must be broadcast
+    against the input follow.  ``**kwargs`` is accepted but unused.
+    """
+    bool_mask = partial(
+        make_tensor, dtype=torch.bool, device=device, requires_grad=False
+    )
+    yield None
+    yield bool_mask(input_shape)
+    if len(input_shape) <= 2:
+        return
+    # Masks that require broadcasting of inputs (mask.ndim > input.ndim) are
+    # not supported; this may be reconsidered if such degenerate cases are needed.
+    yield bool_mask(input_shape[:-1] + (1,))  # broadcast last mask dimension
+    yield bool_mask(input_shape[:1] + (1,) + input_shape[2:])  # broadcast middle
+    yield bool_mask((1,) + input_shape[1:])  # broadcast first mask dimension
+    yield bool_mask(input_shape[1:])  # mask.ndim < input.ndim
+    yield bool_mask(input_shape[-1:])  # mask.ndim == 1
+
+
+def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked reduction operators.
+
+    Masked reduction operator is a reduction operator with trailing
+    mask optional argument. A mask is a bool tensor with the same
+    shape as input or a shape that is broadcastable to input shape.
+
+    Every sample from ``sample_inputs_reduction`` is paired with every mask
+    from ``_generate_masked_op_mask``.  For 2-D floating-point inputs with a
+    full-shape mask and ``requires_grad=False``, extra samples whose diagonal
+    is filled with ``inf``, ``-inf`` and ``nan`` are yielded as well.
+    """
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+
+    for sample_input in sample_inputs_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        base, args = sample_input.input, sample_input.args
+        wants_special_values = (
+            not requires_grad and dtype.is_floating_point and base.ndim == 2
+        )
+        for mask in _generate_masked_op_mask(base.shape, device, **kwargs):
+            sample_kwargs = dict(mask=mask, **sample_input.kwargs)
+            yield SampleInput(
+                base.detach().requires_grad_(requires_grad),
+                args=args,
+                kwargs=sample_kwargs,
+            )
+            if not wants_special_values or mask is None or mask.shape != base.shape:
+                continue
+            for v in (torch.inf, -torch.inf, torch.nan):
+                # Fill a private copy: ``detach()`` alone shares storage with
+                # the input, so an in-place fill would overwrite the samples
+                # that were already yielded from the same input.
+                t = base.detach().clone()
+                t.diagonal(0, -2, -1).fill_(v)
+                yield SampleInput(
+                    t.requires_grad_(requires_grad), args=args, kwargs=sample_kwargs
+                )
+
+
+def sample_inputs_sparse_coo_masked_reduction(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    """Sample inputs for masked reduction operators that support inputs
+    with sparse coo layouts.
+
+    Nothing is yielded unless ``op_info.supports_sparse`` is set.  Both the
+    input and, when present, the mask are converted with ``to_sparse()``.
+    """
+    if not op_info.supports_sparse:
+        return
+    # FIXME: for now reductions with non-zero reduction identity and
+    # unspecified mask are not supported for sparse COO tensors, see
+    # torch.masked.prod implementation for details.
+    needs_explicit_mask = op_info.name.replace("masked.", "") in {
+        "prod",
+        "amax",
+        "amin",
+    }
+    for dense in sample_inputs_masked_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        dense_mask = dense.kwargs.get("mask")
+        if dense_mask is None:
+            if needs_explicit_mask:
+                continue
+            coo_kwargs = dense.kwargs
+        else:
+            coo_kwargs = dense.kwargs.copy()
+            coo_kwargs.update(mask=dense_mask.to_sparse())
+        coo_input = dense.input.to_sparse()
+        yield SampleInput(coo_input, args=dense.args, kwargs=coo_kwargs)
+
+
+def sample_inputs_sparse_csr_masked_reduction(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    """Sample inputs for masked reduction operators that support inputs
+    with sparse csr layouts.  Nothing is yielded unless
+    ``op_info.supports_sparse_csr`` is set.
+    """
+    if op_info.supports_sparse_csr:
+        op_name = op_info.name.replace("masked.", "")
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            if not (
+                sample_input.input.ndim == 2 and sample_input.kwargs.get("keepdim")
+            ):
+                # - sparse CSR tensors are always 2-D tensors
+                # - masked reduction on CSR tensors is defined only if keepdim is True.
+                continue
+            mask = sample_input.kwargs.get("mask")
+            if mask is not None:
+                # Work on a copy so the dense sample's kwargs keep the dense mask.
+                sample_input_kwargs = sample_input.kwargs.copy()
+                sample_input_kwargs.update(mask=mask.to_sparse_csr())
+                new_sample = SampleInput(
+                    sample_input.input.to_sparse_csr(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+            else:
+                if op_name in ["prod", "amax", "amin", "mean"]:
+                    # Reductions with non-zero reduction identity and unspecified
+                    # mask are not supported for sparse CSR tensors, see
+                    # torch.masked.prod implementation for details.
+                    continue
+                new_sample = SampleInput(
+                    sample_input.input.to_sparse_csr(),
+                    args=sample_input.args,
+                    kwargs=sample_input.kwargs,
+                )
+            yield new_sample
+            if sample_input.kwargs["dim"] == 0:
+                # Reductions of CSR tensors use different implementations for
+                # inner and/or outer dimensions, so at a minimum these kwargs
+                # must be exercised:
+                #   dict(dim=0, keepdim=True), dict(dim=1, keepdim=True),
+                #   dict(dim=(0, 1), keepdim=True)
+                # Here we generate the dim=1 case from the dim=0 case.
+                sample_input_kwargs = new_sample.kwargs.copy()
+                sample_input_kwargs.update(dim=1)
+                yield SampleInput(
+                    new_sample.input.clone(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+
+
+def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked norm.
+
+    Re-emits every masked-reduction sample once per norm order, with the
+    order prepended to the positional arguments.
+    """
+    norm_orders = (2.0, 1, float("inf"), float("-inf"), 0)
+    for order in norm_orders:
+        for base in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            norm_args = (order,) + base.args
+            norm_kwargs = base.kwargs.copy()
+            fresh_input = base.input.clone().requires_grad_(requires_grad)
+            yield SampleInput(fresh_input, args=norm_args, kwargs=norm_kwargs)
+
+
+def reference_masked_std_var(
+    numpy_fn,
+):
+    """Wrap ``numpy_fn`` as a reference implementation for masked std/var.
+
+    The returned callable translates the ``unbiased`` / ``correction``
+    arguments into a ``ddof`` keyword and forwards everything else to
+    ``reference_reduction_numpy(numpy_fn)``.
+
+    ``correction`` wins over ``unbiased``; when neither is given ``ddof`` is 1.
+    A sequence ``dim`` is converted to a tuple before the call.
+    """
+    numpy_reference = reference_reduction_numpy(numpy_fn)
+
+    def func(input, dim=None, unbiased=None, *, correction=None, **kwargs):
+        # Start from ``unbiased`` (missing counts as True), then let an
+        # explicit ``correction`` override the result.
+        ddof = 1 if unbiased is None or unbiased else 0
+        if correction is not None:
+            ddof = correction
+        # Lists and other sequences of dims are passed on as a tuple.
+        reduce_dims = tuple(dim) if isinstance(dim, Sequence) else dim
+
+        return numpy_reference(input, reduce_dims, ddof=ddof, **kwargs)
+
+    return func
+
+
+def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked std/var.
+
+    Pairs every ``sample_inputs_std_var`` sample (except those passing
+    ``unbiased`` positionally) with every mask from
+    ``_generate_masked_op_mask``, then drops the samples whose smallest
+    per-reduction element count is at most ``correction + 1``.
+    """
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+    from torch.testing._internal.common_methods_invocations import sample_inputs_std_var
+
+    def masked_samples():
+        for sample_input in sample_inputs_std_var(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            if len(sample_input.args) and isinstance(sample_input.args[0], bool):
+                continue  # masked.{std, var} doesn't support `.var(unbiased)`
+
+            for mask in _generate_masked_op_mask(
+                sample_input.input.shape, device, **kwargs
+            ):
+                sample_input_args = sample_input.args
+                sample_input_kwargs = dict(mask=mask, **sample_input.kwargs)
+                yield SampleInput(
+                    sample_input.input.detach().requires_grad_(requires_grad),
+                    args=sample_input_args,
+                    kwargs=sample_input_kwargs,
+                )
+                if (
+                    not requires_grad
+                    and dtype.is_floating_point
+                    and sample_input.input.ndim == 2
+                    and mask is not None
+                    and mask.shape == sample_input.input.shape
+                ):
+                    for v in [torch.inf, -torch.inf, torch.nan]:
+                        # Clone first: ``detach()`` shares storage with the
+                        # input, so filling in place would also overwrite the
+                        # samples that were already yielded from it.
+                        t = sample_input.input.detach().clone()
+                        t.diagonal(0, -2, -1).fill_(v)
+                        yield SampleInput(
+                            t.requires_grad_(requires_grad),
+                            args=sample_input_args,
+                            kwargs=sample_input_kwargs,
+                        )
+
+    for sample_input in masked_samples():
+        correction = sample_input.kwargs.get("correction")
+        if correction is None:
+            correction = int(sample_input.kwargs.get("unbiased", True))
+
+        dim = sample_input.kwargs.get("dim", None)
+
+        # Count the elements taking part in each reduction.
+        if sample_input.kwargs.get("mask") is None:
+            ones = torch.ones(sample_input.input.shape, dtype=torch.int64)
+            orig_count = torch.masked.sum(ones, dim, keepdim=True)
+        else:
+            inmask = torch.masked._input_mask(
+                sample_input.input, *sample_input.args, **sample_input.kwargs
+            )
+            ones = inmask.new_ones(sample_input.input.shape, dtype=torch.int64)
+            orig_count = torch.masked.sum(ones, dim, keepdim=True, mask=inmask)
+        if orig_count.min() <= correction + 1:
+            # Skip samples that lead to nans in var computation
+            continue
+        yield sample_input
+
+
+def sample_inputs_masked_softmax(
+    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
+):
+    """Sample inputs for masked softmax, log_softmax, and softmin.
+
+    A masked normalization operator takes an optional trailing ``mask``
+    argument: a bool tensor with the same shape as the input, or a shape
+    that is broadcastable to the input shape.  Each softmax-variant sample
+    is emitted once per generated mask, on a fresh clone of its input.
+    """
+    unmasked_samples = sample_inputs_softmax_variant(
+        op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
+    )
+    for unmasked in unmasked_samples:
+        source = unmasked.input
+        for mask in _generate_masked_op_mask(source.shape, device, **kwargs):
+            yield SampleInput(
+                source.clone().requires_grad_(requires_grad),
+                *unmasked.args,
+                mask=mask,
+                **unmasked.kwargs,
+            )
+
+
+def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked cumsum and cumprod.
+
+    Only samples with an actual tensor mask are produced, ``keepdim`` is
+    dropped, and the (required) dimension is always passed positionally.
+    """
+    for base in sample_inputs_softmax_variant(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            base.input.shape, device, **kwargs
+        ):
+            if type(mask) is not torch.Tensor:
+                continue
+            cum_kwargs = dict(mask=mask, **base.kwargs)
+            # ``keepdim`` is never forwarded to the cumulative ops.
+            cum_kwargs.pop("keepdim", None)
+            cum_args = base.args
+            if not cum_args:
+                # The dimension is required: promote a keyword ``dim`` to the
+                # positional slot, or skip the sample when there is none.
+                if "dim" not in cum_kwargs:
+                    continue
+                cum_args = (cum_kwargs.pop("dim"),)
+            yield SampleInput(
+                base.input.clone().requires_grad_(requires_grad),
+                *cum_args,
+                **cum_kwargs,
+            )
+
+
+def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked logaddexp.
+
+    For each shape, two independently generated mask sequences are walked in
+    lockstep and paired with two freshly created operands of that shape.
+    """
+    shapes = ((S,), (S, S), (S, M, S))
+
+    def masks_per_shape():
+        # Materialized eagerly: all masks exist before any operand is created.
+        return [list(_generate_masked_op_mask(s, device, **kwargs)) for s in shapes]
+
+    input_mask_lists = masks_per_shape()
+    other_mask_lists = masks_per_shape()
+    new_operand = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    for shape, input_masks, other_masks in zip(
+        shapes, input_mask_lists, other_mask_lists, strict=True
+    ):
+        for input_mask, other_mask in zip(input_masks, other_masks, strict=True):
+            operands = (new_operand(shape), new_operand(shape))
+            yield SampleInput(*operands, input_mask=input_mask, other_mask=other_mask)
+
+
+def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked normalize.
+
+    Every softmax-variant sample (without zero-sized dimensions) is emitted
+    once per norm order, with the order as first positional argument.
+    """
+    for order in (2.0, 1, float("inf"), float("-inf"), 0):
+        for variant in sample_inputs_softmax_variant(
+            op_info, device, dtype, requires_grad, use_zero_dimensions=False, **kwargs
+        ):
+            fresh = variant.input.clone().requires_grad_(requires_grad)
+            yield SampleInput(fresh, order, *variant.args, **variant.kwargs)
+
+
+op_db: list[OpInfo] = [
+    ReductionOpInfo(
+        "masked.sum",
+        ref=reference_reduction_numpy(np.sum),
+        method_variant=None,
+        identity=0,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_reference_masked",
+                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-03, rtol=5e-2),
+                        torch.float16: tol(atol=1e-03, rtol=5e-3),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=0.1, rtol=0.1),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestMasked",
+                "test_mask_layout",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+    ),
+    ReductionOpInfo(
+        "masked.prod",
+        ref=prod_numpy,
+        method_variant=None,
+        identity=1,
+        nan_policy="propagate",
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_reference_masked",
+                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
+            ),
+            DecorateInfo(
+                "TestReductions",
+                "test_ref_small_input",
+                dtypes=(torch.int8, torch.int16, torch.int32),
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cuda",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_duplicate_values",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
+                "TestOperators",
+                "test_jvp",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+    ),
+    OpInfo(
+        "masked.cumsum",
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        # Can reuse the same inputs; dim is required in both
+        sample_inputs_func=sample_inputs_masked_cumops,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    OpInfo(
+        "masked.cumprod",
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                "TestCompositeCompliance",
+                "test_backward",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ),
+        # Can reuse the same inputs; dim is required in both
+        sample_inputs_func=sample_inputs_masked_cumops,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.amax",
+        nan_policy="propagate",
+        supports_out=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        supports_sparse=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse_csr=True,
+        ref=reference_reduction_numpy(np.amax),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: amax reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: Unknown builtin op: aten::iinfo
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.amin",
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        ref=reference_reduction_numpy(np.amin),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: amin reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: Unknown builtin op: aten::iinfo
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.argmax",
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # initial is not a keyword for argmax
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_reference_masked"
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.argmin",
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # initial is not a keyword for argmin
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_reference_masked"
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.mean",
+        ref=reference_reduction_numpy(np.mean)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_sparse_csr=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-03, rtol=0.05),
+                        torch.float16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=2e-03)}),
+                "TestSparseCompressed",
+                "test_consistency",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    OpInfo(
+        "masked.median",
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        method_variant=None,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=partial(
+            sample_inputs_masked_softmax, use_zero_dimensions=False
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.norm",
+        identity=0,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        promotes_int_to_float=True,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # torch.jit.frontend.NotSupportedError: Compiled functions
+            # can't take variable number of arguments or use
+            # keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_masked_norm,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.var",
+        ref=reference_masked_std_var(np.var)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=4e-5, rtol=2e-2),
+                    }
+                ),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_std_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        check_batched_grad=True,
+    ),
+    ReductionOpInfo(
+        "masked.std",
+        ref=reference_masked_std_var(np.std)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-02, rtol=1e-02),
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=5e-03, rtol=5e-04),
+                    }
+                ),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_std_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        check_batched_grad=True,
+    ),
+    OpInfo(
+        "masked.softmax",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.log_softmax",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+        ],
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.softmin",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME:
+            # Mismatched elements: 2 / 2 (100.0%)
+            # Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed)
+            # Greatest relative difference: nan at index (0,) (up to 0.0001 allowed)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestOperators",
+                "test_vmapvjpvjp",
+                device_type="cpu",
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.normalize",
+        method_variant=None,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_normalize,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.logaddexp",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestFwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestBwdGradients", "test_fn_gradgrad"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_logaddexp,
+        gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation,
+    ),
+    ReductionOpInfo(
+        "masked.logsumexp",
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # Identity can't be -torch.inf without overflow
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestReductions",
+                "test_empty_tensor_empty_slice",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            # all the values are the same except for -inf vs nan
+            DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
+            # FIXME:
+            # Mismatched elements: 2 / 12 (16.7%)
+            # Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0)
+            # Greatest relative difference: 0.0 at index (0, 0, 0, 1)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cpu",
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py
new file mode 100644
index 0000000000000000000000000000000000000000..8293fca978f262d7bf6eea6b546b2c3cd500f227
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py
@@ -0,0 +1,809 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import SM53OrLater
+from torch.testing._internal.common_device_type import precisionOverride
+from torch.testing._internal.common_dtype import (
+    all_types_and,
+    all_types_and_complex_and,
+)
+from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    ErrorInput,
+    OpInfo,
+    sample_inputs_spectral_ops,
+    SampleInput,
+    SpectralFuncInfo,
+    SpectralFuncType,
+)
+from torch.testing._internal.opinfo.refs import (
+    _find_referenced_opinfo,
+    _inherit_constructor_args,
+    PythonRefInfo,
+)
+
+
+has_scipy_fft = False  # stays False unless the optional scipy.fft import below succeeds
+if TEST_SCIPY:
+    try:
+        import scipy.fft
+
+        has_scipy_fft = True  # consulted in op_db before referencing scipy.fft.hfft2 / scipy.fft.hfftn
+    except ModuleNotFoundError:
+        pass  # scipy.fft could not be imported; the flag keeps its False default
+
+
+class SpectralFuncPythonRefInfo(SpectralFuncInfo):
+    """
+    An OpInfo for a Python reference of a spectral function; constructor args are inherited from the referenced torch SpectralFuncInfo.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch.<name> if None
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant="",  # variant name used together with torch_opinfo_name for the lookup
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
+        )  # the lookup is restricted to this module's own op_db list
+        assert isinstance(self.torch_opinfo, SpectralFuncInfo)  # the referenced entry must itself be spectral
+
+        inherited = self.torch_opinfo._original_spectral_func_args  # constructor args recorded on the torch opinfo
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)  # merge inherited args with the overrides
+
+        super().__init__(**ukwargs)
+
+
+def error_inputs_fft(op_info, device, **kwargs):  # generator of ErrorInput cases; op_info and kwargs are unused
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    # A zero-dimensional tensor has no dimension to take the FFT of; no dim argument is passed here
+    yield ErrorInput(
+        SampleInput(make_arg()),
+        error_type=IndexError,
+        error_regex="Dimension specified as -1 but tensor has no dimensions",
+    )
+
+
+def error_inputs_fftn(op_info, device, **kwargs):  # generator of ErrorInput cases; op_info and kwargs are unused
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    # Explicitly specifying dim=(0,) on a zero-dimensional tensor is expected to raise IndexError
+    yield ErrorInput(
+        SampleInput(make_arg(), dim=(0,)),
+        error_type=IndexError,
+        error_regex="Dimension specified as 0 but tensor has no dimensions",
+    )
+
+
+def sample_inputs_fft_with_min(
+    op_info, device, dtype, requires_grad=False, *, min_size, **kwargs
+):  # min_size is keyword-only and required; callers in this file bind it with functools.partial
+    yield from sample_inputs_spectral_ops(
+        op_info, device, dtype, requires_grad, **kwargs
+    )  # first forward every sample produced by the shared spectral-op generator
+    if TEST_WITH_ROCM:
+        # FIXME: Causes floating point exception on ROCm
+        return  # so the extra min_size sample below is not produced on ROCm
+
+    # Check the "Invalid number of data points" error isn't too strict
+    # https://github.com/pytorch/pytorch/pull/109083
+    a = make_tensor(min_size, dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(a)  # one extra sample of shape min_size, with no additional args or kwargs
+
+
+def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):  # yields five SampleInputs
+    def mt(shape, **kwargs):  # make_tensor bound to this call's device, dtype and requires_grad
+        return make_tensor(
+            shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+        )
+
+    yield SampleInput(mt((9, 10)))  # dim omitted
+    yield SampleInput(mt((50,)), kwargs=dict(dim=0))  # dim given as a plain int
+    yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,)))  # dim given as a 1-tuple
+    yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1)))  # every dim of a 2-D input
+    yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2)))  # two non-adjacent dims of a 3-D input
+
+
+# Operator database
+op_db: list[OpInfo] = [
+    SpectralFuncInfo(
+        "fft.fft",
+        aten_name="fft_fft",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.fft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.fft2",
+        aten_name="fft_fft2",
+        ref=np.fft.fft2,
+        decomp_aten_name="_fft_c2c",
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+                device_type="cuda",
+                dtypes=[torch.complex32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.fftn",
+        aten_name="fft_fftn",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.fftn,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
+    ),
+    SpectralFuncInfo(
+        "fft.hfft",
+        aten_name="fft_hfft",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.hfft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=2),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        check_batched_gradgrad=False,
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.hfft2",
+        aten_name="fft_hfft2",
+        decomp_aten_name="_fft_c2r",
+        ref=scipy.fft.hfft2 if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_gradgrad=False,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+        ],
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+            ),
+            # FIXME: errors are too large; needs investigation
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.hfftn",
+        aten_name="fft_hfftn",
+        decomp_aten_name="_fft_c2r",
+        ref=scipy.fft.hfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_gradgrad=False,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+        ],
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.rfft",
+        aten_name="fft_rfft",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        skips=(),
+        check_batched_gradgrad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.rfft2",
+        aten_name="fft_rfft2",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfft2,
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            precisionOverride({torch.float: 1e-4}),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.rfftn",
+        aten_name="fft_rfftn",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfftn,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            precisionOverride({torch.float: 1e-4}),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ifft",
+        aten_name="fft_ifft",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.ifft2",
+        aten_name="fft_ifft2",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifft2,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ifftn",
+        aten_name="fft_ifftn",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifftn,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ihfft",
+        aten_name="fft_ihfft",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.ihfft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fft,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        skips=(),
+        check_batched_grad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.ihfft2",
+        aten_name="fft_ihfft2",
+        decomp_aten_name="_fft_r2c",
+        ref=scipy.fft.ihfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=(
+            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
+            ),
+            # Mismatched elements!
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warnings"),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.ihfftn",
+        aten_name="fft_ihfftn",
+        decomp_aten_name="_fft_r2c",
+        ref=scipy.fft.ihfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+            # Mismatched elements!
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
+            ),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.irfft",
+        aten_name="fft_irfft",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.irfft2",
+        aten_name="fft_irfft2",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfft2,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.irfftn",
+        aten_name="fft_irfftn",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfftn,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    OpInfo(
+        "fft.fftshift",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.half, torch.chalf
+        ),
+        sample_inputs_func=sample_inputs_fftshift,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    OpInfo(
+        "fft.ifftshift",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.half, torch.chalf
+        ),
+        sample_inputs_func=sample_inputs_fftshift,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft",
+        torch_opinfo_name="fft.fft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft",
+        torch_opinfo_name="fft.ifft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft",
+        torch_opinfo_name="fft.rfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft",
+        torch_opinfo_name="fft.irfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft",
+        torch_opinfo_name="fft.hfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft",
+        torch_opinfo_name="fft.ihfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fftn",
+        torch_opinfo_name="fft.fftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifftn",
+        torch_opinfo_name="fft.ifftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfftn",
+        torch_opinfo_name="fft.rfftn",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfftn",
+        torch_opinfo_name="fft.irfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfftn",
+        torch_opinfo_name="fft.hfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfftn",
+        torch_opinfo_name="fft.ihfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+            # AssertionError: Reference result was farther (0.09746177145360499) from the precise
+            # computation than the torch result was (0.09111555632069855)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_torch_fallback",
+                dtypes=(torch.float16,),
+                device_type="cuda",
+            ),
+            # AssertionError: Reference result was farther (0.0953431016138116) from the precise
+            # computation than the torch result was (0.09305490684430734)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_executor",
+                dtypes=(torch.float16,),
+                device_type="cuda",
+            ),
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft2",
+        torch_opinfo_name="fft.fft2",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft2",
+        torch_opinfo_name="fft.ifft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft2",
+        torch_opinfo_name="fft.rfft2",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft2",
+        torch_opinfo_name="fft.irfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft2",
+        torch_opinfo_name="fft.hfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft2",
+        torch_opinfo_name="fft.ihfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+            # FIXME:
+            # Reference result was farther (0.0953431016138116) from the precise computation
+            # than the torch result was (0.09305490684430734)!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_executor",
+                device_type="cuda",
+            ),
+        ],
+    ),
+    PythonRefInfo(
+        "_refs.fft.fftshift",
+        op_db=op_db,
+        torch_opinfo_name="fft.fftshift",
+    ),
+    PythonRefInfo(
+        "_refs.fft.ifftshift",
+        op_db=op_db,
+        torch_opinfo_name="fft.ifftshift",
+    ),
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py
new file mode 100644
index 0000000000000000000000000000000000000000..f41cadad67eb780aa6980306002a27cacfd2eb30
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -0,0 +1,2392 @@
+# mypy: ignore-errors
+
+import itertools
+import random
+import unittest
+from collections.abc import Iterable
+from functools import partial
+from itertools import chain, product
+
+import numpy as np
+from numpy import inf
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import _get_magma_version, with_tf32_off
+from torch.testing._internal.common_device_type import (
+    has_cusolver,
+    skipCPUIfNoLapack,
+    skipCUDAIfNoCusolver,
+    skipCUDAIfNoMagma,
+    skipCUDAIfNoMagmaAndNoCusolver,
+    skipCUDAIfNoMagmaAndNoLinalgsolver,
+    skipCUDAIfRocm,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import (
+    all_types_and_complex,
+    all_types_and_complex_and,
+    floating_and_complex_types,
+    floating_and_complex_types_and,
+)
+from torch.testing._internal.common_utils import (
+    GRADCHECK_NONDET_TOL,
+    make_fullrank_matrices_with_distinct_singular_values,
+    skipIfSlowGradcheckEnv,
+    slowTest,
+    TEST_WITH_ROCM,
+    TEST_XPU,
+)
+from torch.testing._internal.opinfo.core import (
+    clone_sample,
+    DecorateInfo,
+    ErrorInput,
+    gradcheck_wrapper_hermitian_input,
+    L,
+    M,
+    OpInfo,
+    ReductionOpInfo,
+    S,
+    SampleInput,
+)
+from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo
+
+
+def sample_kwargs_vector_norm(t, **kwargs):
+    # orders with / without identity
+    def ords():
+        has_id = (6, 4, 2, 1, 0, 0.9)
+        no_id = (inf, -2.1, -inf)
+        if t.numel() == 0:
+            dim = kwargs.get("dim")
+            if dim is None:
+                return has_id
+            if not isinstance(dim, Iterable):
+                dim = (dim,)
+            for d in dim:
+                if t.size(d) == 0:
+                    return has_id
+        return has_id + no_id
+
+    return (((), dict(ord=o)) for o in ords())
+
+
+def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    is_linalg_svd = "linalg.svd" in op_info.name
+    batches = [(), (0,), (3,)]
+    ns = [0, 3, 5]
+
+    def uniformize(usv):
+        S = usv[1]
+        k = S.shape[-1]
+        U = usv[0][..., :k]
+        Vh = usv[2] if is_linalg_svd else usv[2].mH
+        Vh = Vh[..., :k, :]
+        return U, S, Vh
+
+    def fn_U(usv):
+        U, _, _ = uniformize(usv)
+        return U.abs()
+
+    def fn_S(usv):
+        return uniformize(usv)[1]
+
+    def fn_Vh(usv):
+        # We also return S to test
+        _, S, Vh = uniformize(usv)
+        return S, Vh.abs()
+
+    def fn_UVh(usv):
+        U, S, Vh = uniformize(usv)
+        return U @ Vh, S
+
+    fns = (fn_U, fn_S, fn_Vh, fn_UVh)
+
+    fullmat = "full_matrices" if is_linalg_svd else "some"
+
+    for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns):
+        shape = batch + (n, k)
+        yield SampleInput(
+            make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn
+        )
+
+
+def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
+    yield SampleInput(
+        make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
+    )
+    yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
+
+
+def error_inputs_cross(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
+    err = "inputs dimension -1 must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
+    err = "inputs must have the same number of dimensions"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
+    err = "must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(
+        input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
+    )
+    err = "Dimension out of range"
+    yield ErrorInput(sample, error_regex=err, error_type=IndexError)
+
+
+def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
+    """
+    This function generates input for torch.linalg.householder_product (torch.orgqr).
+    The first argument is a square or rectangular matrix (or a batch of them), the second argument is a vector or batch of vectors.
+    Empty, square, rectangular, batched square and batched rectangular input is generated.
+    """
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        low=-2,
+        high=2,
+    )
+    # Each column of the matrix is getting multiplied many times leading to very large values for
+    # the Jacobian matrix entries and making the finite-difference result of grad check less accurate.
+    # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here.
+    yield SampleInput(make_arg((S, S)), make_arg((S,)))
+    yield SampleInput(make_arg((S + 1, S)), make_arg((S,)))
+    yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S)))
+    yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S)))
+    yield SampleInput(
+        make_arg((0, 0), low=None, high=None),
+        make_arg((0,), low=None, high=None),
+    )
+    yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None))
+    # m = n = S, k = S - 2
+    yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None))
+    # m = S, n = S -1, k = S - 2
+    yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None))
+
+
+def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    make_arg_fullrank = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    # (matrix_size, (batch_sizes, ...))
+    test_sizes = [
+        (1, ()),
+        (2, (0,)),
+        (2, (2,)),
+    ]
+
+    for matrix_size, batch_sizes in test_sizes:
+        size = batch_sizes + (matrix_size, matrix_size)
+        for n in (0, 3, 5):
+            yield SampleInput(make_arg(size), args=(n,))
+        for n in [-4, -2, -1]:
+            yield SampleInput(make_arg_fullrank(*size), args=(n,))
+
+
+def sample_inputs_linalg_det_logdet_slogdet(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    batches = [(), (0,), (3,)]
+    ns = [0, 1, 5]
+
+    is_logdet = op_info.name == "logdet"
+
+    for (
+        batch,
+        n,
+    ) in product(batches, ns):
+        shape = batch + (n, n)
+        A = make_arg(*shape)
+        # Need to make the matrices in A have positive determinant for autograd
+        # To do so, we multiply A by the sign of its determinant (n is odd here, so det flips)
+        if is_logdet and not A.is_complex() and A.numel() > 0:
+            s = torch.linalg.slogdet(A).sign
+            A = A * s.unsqueeze(-1).unsqueeze(-1)
+            A.requires_grad_(requires_grad)
+        yield SampleInput(A)
+
+
+def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    """Samples the inputs for both linalg.lu_solve and lu_solve"""
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(make_fn, dtype=dtype, device=device)
+    make_b = partial(make_tensor, dtype=dtype, device=device)
+
+    def clone(X, requires_grad):
+        Y = X.clone()
+        Y.requires_grad_(requires_grad)
+        return Y
+
+    is_linalg_lu_solve = op_info.name == "linalg.lu_solve"
+
+    batches = ((), (0,), (2,))
+    ns = (3, 1, 0)
+    nrhs = (4, 1, 0)
+
+    for n, batch, rhs in product(ns, batches, nrhs):
+        A = make_a(*(batch + (n, n)))
+        if torch.device(device).type == "mps":
+            # TODO: Fix lu_factor for MPS, because it does not work for all of
+            # these cases. So we resort to the CPU impl here and move the
+            # outputs back to MPS.
+            LU, pivots = (x.to(device) for x in torch.linalg.lu_factor(A.cpu()))
+        else:
+            LU, pivots = torch.linalg.lu_factor(A)
+
+        B = make_b(batch + (n, rhs))
+
+        grads = (False,) if not requires_grad else (True, False)
+        # we try all possible combinations of requires_grad for each input
+        for LU_grad, B_grad in product(grads, grads):
+            # when requires_grad == True, at least one input has to have requires_grad enabled
+            if requires_grad and not LU_grad and not B_grad:
+                continue
+
+            if is_linalg_lu_solve:
+                for adjoint, left in product((True, False), repeat=2):
+                    yield SampleInput(
+                        clone(LU, LU_grad),
+                        args=(pivots, clone(B if left else B.mT, B_grad)),
+                        kwargs=dict(adjoint=adjoint, left=left),
+                    )
+            else:
+                yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots))
+
+
+def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
+    # Each test case consists of the sizes in the chain of multiplications
+    # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
+    test_cases = [
+        [1, 2, 1],
+        [2, 0, 2],
+        [0, 2, 2],
+        [2, 2, 2, 2],
+        [2, 3, 4, 5],
+        [5, 4, 0, 2],
+        [2, 4, 3, 5, 3, 2],
+    ]
+
+    for sizes in test_cases:
+        tensors = []
+        for size in itertools.pairwise(sizes):
+            t = make_tensor(
+                size, dtype=dtype, device=device, requires_grad=requires_grad
+            )
+            tensors.append(t)
+        yield SampleInput(tensors)
+
+
+def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs):
+    low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    sizes = ((2, 2), (2, 3, 2))
+    if dtype in low_precision_dtypes:
+        # svdvals not supported for low precision dtypes
+        ords = ("fro", inf, -inf, 1, -1)
+    else:
+        ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2)
+    dims = ((-2, -1), (-1, 0))
+
+    for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]):
+        yield SampleInput(make_arg(size), args=(ord, dim, keepdim))
+
+
+def sample_inputs_linalg_norm(
+    op_info, device, dtype, requires_grad, *, variant=None, **kwargs
+):
+    if variant is not None and variant != "subgradient_at_zero":
+        raise ValueError(
+            f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
+        )
+
+    test_sizes = [
+        (S,),
+        (0,),
+        (S, S),
+        (0, 0),
+        (S, 0),
+        (0, S),
+        (S, S, S),
+        (0, S, S),
+        (S, 0, S),
+        (0, 0, 0),
+    ]
+
+    vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf)
+    if dtype in {torch.float16, torch.bfloat16, torch.complex32}:
+        # svdvals not supported for low precision dtypes
+        matrix_ords = ("fro", inf, -inf, 1, -1)
+    else:
+        matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2)
+
+    make_arg = partial(
+        make_tensor,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+        low=None,
+        high=None,
+    )
+
+    for test_size in test_sizes:
+        is_vector_norm = len(test_size) == 1
+        is_matrix_norm = len(test_size) == 2
+
+        # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
+        is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0)
+
+        for keepdim in [False, True]:
+            if variant != "subgradient_at_zero" and is_valid_for_p2:
+                yield SampleInput(make_arg(test_size), keepdim=keepdim)
+
+            if not (is_vector_norm or is_matrix_norm):
+                continue
+
+            ords = vector_ords if is_vector_norm else matrix_ords
+
+            for ord in ords:
+                if is_vector_norm and test_size[-1] == 0:
+                    if ord == np.inf or (ord is not None and ord < 0):
+                        # RuntimeError: linalg.vector_norm cannot compute the
+                        # {ord} norm on an empty tensor because the operation
+                        # does not have an identity
+                        continue
+                elif is_matrix_norm:
+                    dims_to_check = {
+                        None: (0,),
+                        -1: (1,),
+                        -2: (0, 1),
+                        -np.inf: (0,),
+                    }.get(ord, ())
+
+                    if any(test_size[d] == 0 for d in dims_to_check):
+                        # IndexError: amax(): Expected reduction dim {dim} to
+                        # have non-zero size.
+                        continue
+
+                    no_grad_dims_to_check = {
+                        np.inf: (0,),
+                        2: (0, 1),
+                        1: (1,),
+                    }.get(ord, ())
+
+                    if (
+                        any(test_size[d] == 0 for d in no_grad_dims_to_check)
+                        and requires_grad
+                    ):
+                        continue
+
+                if variant == "subgradient_at_zero":
+                    yield SampleInput(
+                        torch.zeros(
+                            test_size,
+                            dtype=dtype,
+                            device=device,
+                            requires_grad=requires_grad,
+                        ),
+                        ord,
+                        keepdim=keepdim,
+                    )
+                else:
+                    yield SampleInput(make_arg(test_size), ord, keepdim=keepdim)
+
+                    if ord in ["nuc", "fro"]:
+                        yield SampleInput(
+                            make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1)
+                        )
+
+
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    batches = ((), (0,), (1,), (5,))
+    ns = (0, 1, 3, 5)
+    for b, n in product(batches, ns):
+        shape = b + (n,)
+        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
+        for i in range(len(shape)):
+            yield SampleInput(
+                make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i)
+            )
+
+
+def sample_inputs_linalg_invertible(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates invertible inputs for linear algebra ops
+    The input is generated as the itertools.product of 'batches' and 'ns'.
+    In total this function generates 8 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (1, 1) - 1x1 batch of matrices
+    'ns' gives 0x0 and 5x5 matrices.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    """
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 0]
+
+    for batch, n in product(batches, ns):
+        yield SampleInput(make_arg(*batch, n, n))
+
+
+def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    Produce matrix_rank inputs for the atol/rtol combinations where each is None
+    or of one shared type (float or tensor), then the default-kwargs inputs.
+    """
+
+    def make_tol_arg(kwarg_type, inp):
+        if kwarg_type == "none":
+            return None
+        if kwarg_type == "float":
+            return 1.0
+        assert kwarg_type == "tensor"
+        return torch.ones(inp.shape[:-2], device=device)  # one tolerance per batch entry
+
+    for tol_type in ["float", "tensor"]:
+        for atol_type, rtol_type in product(["none", tol_type], repeat=2):
+            if (
+                atol_type == "none" and rtol_type == "none"
+            ):  # default behavior, so skipped here so it's not tested 2 extra times
+                continue
+            for sample in sample_inputs_linalg_invertible(
+                op_info, device, dtype, requires_grad
+            ):
+                assert sample.kwargs == {}
+                sample.kwargs = {
+                    "atol": make_tol_arg(atol_type, sample.input),
+                    "rtol": make_tol_arg(rtol_type, sample.input),
+                }
+                yield sample
+
+    # default kwargs
+    yield from sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+
+
+def sample_inputs_linalg_pinv_singular(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to
+    test the backward method of `linalg_pinv`. That way we always preserve the rank of the
+    input no matter the perturbations applied to it by the gradcheck.
+    Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood.
+    """
+    batches = [(), (0,), (2,), (1, 1)]
+    # the size of at least 30 is required to cause failures for the previous implicit implementation
+    # of the pinv's backward method, albeit it is slow.
+    size = [0, 3, 50]
+
+    for batch, m, n in product(batches, size, size):
+        for k in range(min(3, m, n)):
+            # Note that by making the columns of `a` and `b` orthonormal we make sure that
+            # the product matrix `a @ b.t()` has condition number 1 when restricted to its image
+            a = (
+                torch.rand(*batch, m, k, device=device, dtype=dtype)
+                .qr()
+                .Q.requires_grad_(requires_grad)
+            )
+            b = (
+                torch.rand(*batch, n, k, device=device, dtype=dtype)
+                .qr()
+                .Q.requires_grad_(requires_grad)
+            )
+            yield SampleInput(a, args=(b,))
+
+
+def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # autograd is not supported for inputs with zero number of elements
+    shapes = (
+        (S, S),
+        (2, S, S),
+        (2, 1, S, S),
+    )
+
+    for shape in shapes:
+        yield SampleInput(make_arg(shape))
+
+
+def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    shapes = (
+        (),
+        (1,),
+        (S,),
+        (2, S),
+    )
+
+    for shape in shapes:
+        if len(shape) > 0 and shape[-1] > 1:
+            yield SampleInput(make_arg(shape))
+        n = shape[-1] if len(shape) > 0 else 1
+        for i in range(3):
+            # n-1, n, n+1
+            N = n + i - 1
+            if N < 2:
+                continue
+            yield SampleInput(make_arg(shape), kwargs=dict(N=N))
+
+
+def np_vander_batched(x, N=None):
+    # Wrapper around np.vander that supports batches of 1 dimension (enough for the tests)
+    if x.ndim == 0:
+        x = x[np.newaxis]
+    if x.ndim == 1:
+        y = np.vander(x, N=N, increasing=True)
+        return y
+    else:
+        if N is None:
+            N = x.shape[-1]
+        y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N))
+        return y
+
+
+def sample_inputs_linalg_cholesky_inverse(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    from torch.testing._internal.common_utils import random_well_conditioned_matrix
+
+    # Cholesky factorization is for positive-definite matrices
+    single_well_conditioned_matrix = random_well_conditioned_matrix(
+        S, S, dtype=dtype, device=device
+    )
+    batch_well_conditioned_matrices = random_well_conditioned_matrix(
+        2, S, S, dtype=dtype, device=device
+    )
+    single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH
+    batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH
+
+    inputs = (
+        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
+        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
+        single_pd,
+        batch_pd,
+    )
+    test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs)
+    for l in test_cases:
+        # generated lower-triangular samples
+        l.requires_grad = requires_grad
+        yield SampleInput(l)  # upper=False by default
+        yield SampleInput(
+            l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False)
+        )
+
+        # generate upper-triangular inputs
+        u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad)
+        yield SampleInput(u, kwargs=dict(upper=True))
+
+
+def sample_inputs_linalg_ldl_factor(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    from torch.testing._internal.common_utils import (
+        random_hermitian_pd_matrix,
+        random_symmetric_pd_matrix,
+    )
+
+    device = torch.device(device)
+
+    # Symmetric inputs
+    yield SampleInput(
+        random_symmetric_pd_matrix(S, dtype=dtype, device=device),
+        kwargs=dict(hermitian=False),
+    )  # single matrix
+    yield SampleInput(
+        random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device),
+        kwargs=dict(hermitian=False),
+    )  # batch of matrices
+    yield SampleInput(
+        torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False)
+    )  # 0x0 matrix
+    yield SampleInput(
+        torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False)
+    )  # zero batch of matrices
+
+    # Hermitian inputs
+    # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
+    magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4)
+    if dtype.is_complex and (device.type == "cpu" or magma_254_available):
+        yield SampleInput(
+            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
+            kwargs=dict(hermitian=True),
+        )  # single matrix
+        yield SampleInput(
+            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
+            kwargs=dict(hermitian=True),
+        )  # batch of matrices
+
+
+def sample_inputs_linalg_ldl_solve(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    # Generate LDL factors of symmetric (and Hermitian on CPU) matrices
+    from torch.testing._internal.common_utils import (
+        random_hermitian_pd_matrix,
+        random_symmetric_pd_matrix,
+    )
+
+    device = torch.device(device)
+    symmetric_inputs = (
+        random_symmetric_pd_matrix(S, dtype=dtype, device=device),  # single matrix
+        random_symmetric_pd_matrix(
+            S, 2, dtype=dtype, device=device
+        ),  # batch of matrices
+        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
+        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
+    )
+    hermitian_inputs = (
+        (
+            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
+            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
+        )
+        if device.type == "cpu" and dtype.is_complex
+        else ()
+    )
+    test_cases1 = (
+        torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs
+    )
+    test_cases2 = (
+        torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs
+    )
+
+    # Symmetric case
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    for test_case in test_cases1:
+        factors, pivots, _ = test_case
+        factors.requires_grad = requires_grad
+        for B_batch_shape in ((), factors.shape[:-2]):
+            B = make_arg((*B_batch_shape, factors.shape[-1], S))
+            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False))
+            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
+            yield SampleInput(
+                clone_factors, args=(pivots, B), kwargs=dict(hermitian=False)
+            )
+
+    # Hermitian case
+    for test_case in test_cases2:
+        factors, pivots, _ = test_case
+        factors.requires_grad = requires_grad
+        for B_batch_shape in ((), factors.shape[:-2]):
+            B = make_arg((*B_batch_shape, factors.shape[-1], S))
+            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True))
+            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
+            yield SampleInput(
+                clone_factors, args=(pivots, B), kwargs=dict(hermitian=True)
+            )
+
+
+def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
+    from torch.testing._internal.common_utils import random_well_conditioned_matrix
+
+    device = torch.device(device)
+
+    drivers: tuple[str, ...]
+    if device.type == "cuda":
+        drivers = ("gels",)
+    else:
+        drivers = ("gels", "gelsy", "gelss", "gelsd")
+
+    # we generate matrices of shape (..., n + delta, n)
+    deltas: tuple[int, ...]
+    if device.type == "cpu" or has_cusolver():
+        deltas = (-1, 0, +1)
+    # only square systems if Cusolver is not available
+    # because we solve a lstsq problem with a transposed matrix in the backward
+    else:
+        deltas = (0,)
+
+    for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
+        shape = batch + (3 + delta, 3)
+        a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
+        a.requires_grad_(requires_grad)
+        b = make_tensor(
+            shape,
+            dtype=dtype,
+            device=device,
+            low=None,
+            high=None,
+            requires_grad=requires_grad,
+        )
+        yield SampleInput(a, b, driver=driver)
+
+
+def error_inputs_lstsq(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(
+        SampleInput(zero_d, args=(zero_d,)),
+        error_type=RuntimeError,
+        error_regex="at least 2 dimensions",
+    )
+
+
+def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(
+        SampleInput(zero_d, args=(zero_d, None)),
+        error_type=RuntimeError,
+        error_regex="at least 2 dimensions",
+    )
+
+
+def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # Shapes for 2D Tensors
+    shapes_2d = ((S, S), (3, 5), (5, 3))
+
+    # Shapes for 3D Tensors
+    shapes_3d = ((S, S, S),)
+
+    kwargs_2d = ({}, dict(offset=2), dict(offset=2), dict(offset=1))
+    kwargs_3d = (
+        dict(offset=1, dim1=1, dim2=2),
+        dict(offset=2, dim1=0, dim2=1),
+        dict(offset=-2, dim1=0, dim2=1),
+    )
+
+    for shape, kwarg in chain(
+        product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)
+    ):
+        yield SampleInput(make_arg(shape), kwargs=kwarg)
+
+
+def error_inputs_diagonal_diag_embed(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    shapes1d = (0, 1, (0,), (1,))
+    shapes2d = ((M, L),)
+    shapes3d = ((M, S, L),)
+
+    kwargs1d = {}
+
+    kwargs2d = (
+        # dim1 == dim2 is not allowed
+        dict(dim1=1, dim2=1),
+        # out of bounds dims are not allowed
+        dict(dim1=10000),
+        dict(dim2=10000),
+    )
+
+    kwargs3d = kwargs2d
+
+    samples1d = product(shapes1d, kwargs1d)
+    samples2d = product(shapes2d, kwargs2d)
+    samples3d = product(shapes3d, kwargs3d)
+
+    for shape, kwargs in chain(samples1d, samples2d, samples3d):
+        arg = make_arg(shape)
+        sample = SampleInput(input=arg, kwargs=kwargs)
+
+        dim1 = kwargs.get("dim1")
+        dim2 = kwargs.get("dim2")
+
+        if "diagonal" in op_info.name:
+            num_dim = arg.dim()
+        elif op_info.name in ("diag_embed", "_refs.diag_embed"):
+            # these are valid inputs for diag_embed
+            if shape in ((0,), (1,)):
+                continue
+            num_dim = arg.dim() + 1
+        else:
+            raise RuntimeError("should be unreachable")
+
+        bound1 = -num_dim
+        bound2 = num_dim - 1
+        dim_range = range(bound1, bound2 + 1)
+        dim1_cond = dim1 and dim1 not in dim_range
+        dim2_cond = dim2 and dim2 not in dim_range
+
+        if dim1 == dim2:
+            err = f"diagonal dimensions cannot be identical {dim1}, {dim2}"
+            yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+        elif dim1_cond or dim2_cond:
+            err_dim = dim1 if dim1_cond else dim2
+            err = (
+                r"Dimension out of range \(expected to be in range of "
+                rf"\[{bound1}, {bound2}\], but got {err_dim}\)"
+            )
+            yield ErrorInput(sample, error_regex=err, error_type=IndexError)
+        else:
+            raise RuntimeError("should be unreachable")
+
+
+def sample_inputs_linalg_cholesky(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates always positive-definite input for torch.linalg.cholesky using
+    random_hermitian_pd_matrix.
+    The input is generated as the itertools.product of 'batches' and 'ns'.
+    In total this function generates 8 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (1, 1) - 1x1 batch of matrices
+    'ns' gives 0x0 and 5x5 matrices.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    """
+    from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 0]
+    for batch, n, upper in product(batches, ns, [True, False]):
+        a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
+        a.requires_grad = requires_grad
+        yield SampleInput(a, upper=upper)
+
+
+def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.eig
+    """
+
+    def out_fn(output):
+        return output[0], abs(output[1])
+
+    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+    for sample in samples:
+        sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument.
+    """
+
+    def out_fn(output):
+        if isinstance(output, tuple):
+            # eigh function
+            return output[0], abs(output[1])
+        else:
+            # eigvalsh function
+            return output
+
+    # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input
+    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+    for sample in samples:
+        # Note: we cannot use np.random.choice here as TorchDynamo
+        # does not support tensors of strings.
+        sample.kwargs = {"UPLO": random.choice(["L", "U"])}
+        sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.pinv with hermitian=False keyword argument.
+    """
+    for o in sample_inputs_linalg_invertible(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        real_dtype = o.input.real.dtype if dtype.is_complex else dtype
+        # requires_grad path for rtol tensor is not implemented
+        for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)):
+            o = clone_sample(o)
+            o.kwargs = {"rtol": rtol}
+            yield o
+
+
+def sample_inputs_linalg_pinv_hermitian(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates input for torch.linalg.pinv with hermitian=True keyword argument.
+    """
+    for o in sample_inputs_linalg_invertible(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        o.kwargs = {"hermitian": True}
+        yield o
+
+
+def sample_inputs_linalg_solve(
+    op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs
+):
+    """
+    This function generates always solvable input for torch.linalg.solve
+    We sample a fullrank square matrix (i.e. invertible) A
+    The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'.
+    The second input is generated as the product of 'batches', 'ns' and 'nrhs'.
+    In total this function generates 18 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices.
+    'ns' gives 0x0 and 5x5 matrices.
+    and 'nrhs' controls the number of vectors to solve for:
+        () - using 1 as the number of vectors implicitly
+        (1,) - same as () but explicit
+        (3,) - solve for 3 vectors.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    'vector_rhs_allowed' controls whether to include nrhs = () to the list of SampleInputs.
+    torch.solve / triangular_solve / cholesky_solve (opposed to torch.linalg.solve) do not allow
+    1D tensors (vectors) as the right-hand-side.
+    Once torch.solve / triangular_solve / cholesky_solve and its testing are removed,
+    'vector_rhs_allowed' may be removed here as well.
+    """
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    make_b = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (2, 2)]
+    ns = [5, 0]
+    if vector_rhs_allowed:
+        nrhs = [(), (1,), (3,)]
+    else:
+        nrhs = [(1,), (3,)]
+
+    for n, batch, rhs in product(ns, batches, nrhs):
+        yield SampleInput(make_a(*batch, n, n), args=(make_b(batch + (n,) + rhs),))
+
+
+def sample_inputs_linalg_solve_triangular(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    make_arg = partial(make_tensor, dtype=dtype, device=device)
+    bs = (1, 2, 0)
+    ns = (3, 0)
+    ks = (1, 3, 0)
+
+    for b, n, k, (left, upper, uni) in product(
+        bs, ns, ks, product((True, False), repeat=3)
+    ):
+        if b == 1:
+            A = make_arg((n, n)) if left else make_arg((k, k))
+            B = make_arg((n, k))
+        else:
+            A = make_arg((b, n, n)) if left else make_arg((b, k, k))
+            B = make_arg((b, n, k))
+        if uni:
+            # Not really necessary, but writing it for consistency
+            A.diagonal(0, -2, -1).fill_(1.0)
+        else:
+            d = A.diagonal(0, -2, -1)
+            d[d.abs() < 1e-6] = 1.0
+        if upper:
+            A.triu_()
+        else:
+            A.tril_()
+        kwargs = {"upper": upper, "left": left, "unitriangular": uni}
+        if requires_grad:
+            for grad_A, grad_B in product((True, False), repeat=2):
+                # Either A or B needs to have a gradient
+                if not grad_A and not grad_B:
+                    continue
+                yield SampleInput(
+                    A.clone().requires_grad_(grad_A),
+                    args=(B.clone().requires_grad_(grad_B),),
+                    kwargs=kwargs,
+                )
+        else:
+            yield SampleInput(A, args=(B,), kwargs=kwargs)
+
+
+def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates always solvable input for legacy solve functions
+    (the ones that are not in torch.linalg module).
+    The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation
+    should have b.ndim >= 2, vectors are not allowed.
+    Also the arguments order is swapped.
+    """
+    out = sample_inputs_linalg_solve(
+        op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
+    )
+
+    def out_fn(output):
+        return output[0]
+
+    # Reverses tensor order
+    for sample in out:
+        sample.input, sample.args = sample.args[0], (sample.input,)
+        if op_info.name == "solve":
+            sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs):
+    full_rank = op_info.name == "linalg.lu_factor"
+    make_fn = (
+        make_tensor
+        if not full_rank
+        else make_fullrank_matrices_with_distinct_singular_values
+    )
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    def out_fn(output):
+        if op_info.name == "linalg.lu":
+            return output[1], output[2]
+        else:
+            return output
+
+    batch_shapes = ((), (3,), (3, 3), (0,))
+    # pivot=False only supported in CUDA
+    pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
+    deltas = (-2, -1, 0, +1, +2)
+    for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
+        shape = batch_shape + (S + delta, S)
+        # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple!
+        A = make_arg(shape) if not full_rank else make_arg(*shape)
+        yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn)
+
+
+def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 2, 0]
+
+    for batch, m, n in product(batches, ns, ns):
+        yield SampleInput(make_arg(batch + (m, n)))
+
+
+def sample_inputs_linalg_qr_geqrf(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    # QR is just well defined when the matrix is full rank
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 2, 0]
+
+    for batch, (m, n) in product(batches, product(ns, ns)):
+        shape = batch + (m, n)
+        yield SampleInput(make_arg(*shape))
+
+
+def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs):
+    a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+    # Zero-dim tensors are not supported in NumPy, so we skip them for now.
+    # NumPy is used in reference check tests.
+    # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix.
+    # a_shapes += [(0, 0, 1, 2, 3, 0)]
+    dimss = [None, (0, 2)]
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    for a_shape, dims in itertools.product(a_shapes, dimss):
+        a = make_arg(a_shape)
+        b = make_arg(a_shape[:2])
+        yield SampleInput(a, b, dims=dims)
+
+
+def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = make_fullrank_matrices_with_distinct_singular_values
+
+    def make_input():
+        return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # lhs / rhs shape can have any number of dimensions as long as their product equals 12
+    shapes = [
+        ((2, 2, 3), (12, 1)),
+        ((4, 3), (6, 1, 2)),
+    ]
+
+    for shape_lhs, shape_rhs in shapes:
+        inp = make_input().reshape(*shape_lhs, *shape_rhs).detach()
+        inp.requires_grad_(requires_grad)
+        yield SampleInput(inp, ind=len(shape_lhs))
+
+
+op_db: list[OpInfo] = [
+    OpInfo(
+        "linalg.cross",
+        ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
+        op=torch.linalg.cross,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        aten_name="linalg_cross",
+        sample_inputs_func=sample_inputs_cross,
+        error_inputs_func=error_inputs_cross,
+        supports_out=True,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.det",
+        aten_name="linalg_det",
+        op=torch.linalg.det,
+        aliases=("det",),
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+        check_batched_gradgrad=False,
+    ),
+    OpInfo(
+        "linalg.diagonal",
+        aten_name="linalg_diagonal",
+        aten_backward_name="diagonal_backward",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.float16, torch.chalf
+        ),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_diagonal_diag_embed,
+        error_inputs_func=error_inputs_diagonal_diag_embed,
+    ),
+    OpInfo(
+        "linalg.cholesky",
+        aten_name="linalg_cholesky",
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_cholesky,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.cholesky_ex",
+        aten_name="linalg_cholesky_ex",
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_cholesky,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.vecdot",
+        aten_name="linalg_vecdot",
+        ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_linalg_vecdot,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.cond",
+        aten_name="linalg_cond",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_cond,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eig",
+        aten_name="linalg_eig",
+        op=torch.linalg.eig,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eig,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # AssertionError: Scalars are not equal!
+            DecorateInfo(
+                unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
+    ),
+    OpInfo(
+        "linalg.eigvals",
+        aten_name="linalg_eigvals",
+        op=torch.linalg.eigvals,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eigh",
+        aten_name="linalg_eigh",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eigh,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eigvalsh",
+        aten_name="linalg_eigvalsh",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eigh,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # Pre-existing condition; Needs to be fixed
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.householder_product",
+        aten_name="linalg_householder_product",
+        op=torch.linalg.householder_product,
+        aliases=("orgqr",),
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        # TODO: backward uses in-place operations that vmap doesn't like
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_householder_product,
+        decorators=[
+            skipCUDAIfNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)})
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped! Flaky"),
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cpu",
+                dtypes=(torch.complex128,),
+            ),
+            skipCUDAIfRocm,  # regression in ROCm 6.4
+        ],
+    ),
+    OpInfo(
+        "linalg.ldl_factor",
+        aten_name="linalg_ldl_factor",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_factor,
+        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.ldl_factor_ex",
+        aten_name="linalg_ldl_factor_ex",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_factor,
+        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.ldl_solve",
+        aten_name="linalg_ldl_solve",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_solve,
+        decorators=[
+            skipCUDAIfNoCusolver,
+            skipCUDAIfRocm,
+            skipCPUIfNoLapack,
+        ],
+    ),
+    OpInfo(
+        "linalg.lstsq",
+        aten_name="linalg_lstsq",
+        dtypes=floating_and_complex_types(),
+        supports_out=True,
+        sample_inputs_func=sample_inputs_linalg_lstsq,
+        error_inputs_func=error_inputs_lstsq,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # we skip gradient checks for this suite as they are tested in
+            # variant_test_name='grad_oriented'
+            DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"),
+            # The values for attribute 'shape' do not match
+            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lstsq",
+        aten_name="linalg_lstsq",
+        variant_test_name="grad_oriented",
+        # gradcheck for forward AD fails with the full output tuple;
+        # it works when taking [:2], which is (solution, residuals)
+        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[:2],
+        supports_out=False,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_lstsq,
+        error_inputs_func=error_inputs_lstsq_grad_oriented,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_autograd=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # tests do not work with passing lambda for op
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestOperatorSignatures",
+                "test_get_torch_func_signature_exhaustive",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_power",
+        aliases=("matrix_power",),
+        aten_name="linalg_matrix_power",
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_inplace_autograd=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-5, rtol=2e-6)}),
+                "TestConsistency",
+                "test_output_grad_match",
+                device_type="mps",
+            ),
+        ),
+        sample_inputs_func=sample_inputs_linalg_matrix_power,
+    ),
+    OpInfo(
+        "linalg.multi_dot",
+        # NOTE(review): no lambda is passed in this entry; original remark: gradcheck does not work with TensorList inputs
+        aten_name="linalg_multi_dot",
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        supports_inplace_autograd=False,
+        # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # https://github.com/pytorch/pytorch/issues/66357
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_multi_dot,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/67470
+            DecorateInfo(
+                unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples"
+            ),
+            # Fails on XLA.
+            # AssertionError: False is not true : Tensors failed to compare as equal!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestOpInfo",
+                device_type="xla",
+                dtypes=(torch.long,),
+            ),
+            # https://github.com/pytorch/pytorch/issues/71774
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNNCOpInfo",
+                "test_nnc_correctness",
+                device_type="cpu",
+                dtypes=(torch.long,),
+            ),
+        ),
+    ),
+    # NB: linalg.norm has two variants so that different skips can be used for different sample inputs
+    OpInfo(
+        "linalg.norm",
+        aten_name="linalg_norm",
+        op=torch.linalg.norm,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=sample_inputs_linalg_norm,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.norm",
+        op=torch.linalg.norm,
+        variant_test_name="subgradients_at_zero",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=partial(
+            sample_inputs_linalg_norm, variant="subgradient_at_zero"
+        ),
+        aten_name="linalg_norm",
+        supports_forward_ad=True,
+        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
+        # Could not allocate memory to change Tensor SizesAndStrides!
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Skips specifically for sample inputs at zero
+            # norm's vjp/jvp are not well-conditioned near zero
+            DecorateInfo(
+                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_norm",
+        aten_name="linalg_matrix_norm",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        check_batched_gradgrad=False,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=sample_inputs_linalg_matrix_norm,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.qr",
+        aten_name="linalg_qr",
+        op=torch.linalg.qr,
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # In-place ops
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_linalg_qr_geqrf,
+        decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.slogdet",
+        aten_name="linalg_slogdet",
+        op=torch.linalg.slogdet,
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.vander",
+        aten_name="linalg_vander",
+        ref=np_vander_batched,
+        op=torch.linalg.vander,
+        dtypes=all_types_and_complex(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_linalg_vander,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    ReductionOpInfo(
+        "linalg.vector_norm",
+        op=torch.linalg.vector_norm,
+        identity=0,
+        nan_policy="propagate",
+        supports_multiple_dims=True,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
+        # got: Could not allocate memory to change Tensor SizesAndStrides!
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        generate_args_kwargs=sample_kwargs_vector_norm,
+        aten_name="linalg_vector_norm",
+    ),
+    OpInfo(
+        "linalg.lu_factor",
+        aten_name="linalg_lu_factor",
+        op=torch.linalg.lu_factor,
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu_factor_ex",
+        aten_name="linalg_lu_factor_ex",
+        op=torch.linalg.lu_factor_ex,
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu",
+        aten_name="linalg_lu",
+        op=torch.linalg.lu,
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu_solve",
+        op=torch.linalg.lu_solve,
+        aten_name="linalg_lu_solve",
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_lu_solve,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Tests different backward paths"),
+                "TestCommon",
+                "test_floating_inputs_are_differentiable",
+            ),
+        ),
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+    ),
+    OpInfo(
+        "linalg.inv",
+        aten_name="linalg_inv",
+        op=torch.linalg.inv,
+        aliases=("inverse",),
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.inv_ex",
+        aten_name="linalg_inv_ex",
+        op=torch.linalg.inv_ex,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve",
+        aten_name="linalg_solve",
+        op=torch.linalg.solve,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve_ex",
+        aten_name="linalg_solve_ex",
+        op=torch.linalg.solve_ex,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve_triangular",
+        aten_name="linalg_solve_triangular",
+        op=torch.linalg.solve_triangular,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve_triangular,
+        supports_fwgrad_bwgrad=True,
+        skips=(skipCPUIfNoLapack,),
+        # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
+        supports_forward_ad=True,
+    ),
+    OpInfo(
+        "linalg.matrix_rank",
+        aten_name="linalg_matrix_rank",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_matrix_rank,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            # jit doesn't accept tensor inputs for matrix rank
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=[torch.complex64, torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_rank",
+        aten_name="linalg_matrix_rank",
+        variant_test_name="hermitian",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        op=torch.linalg.pinv,
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_pinv,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # errors with "leaked XXXX bytes CUDA memory on device 0"
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        variant_test_name="singular",
+        # pinv is Frechet-differentiable in a rank-preserving neighborhood,
+        # so we feed inputs that are the products of two full-rank factors,
+        # to avoid any rank changes caused by the perturbations in the gradcheck
+        op=lambda a, b: torch.linalg.pinv(a @ b.mT),
+        dtypes=floating_and_complex_types(),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_pinv_singular,
+        # Only large tensors show issues with implicit backward used prior to
+        # explicit backward implementation.
+        decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # CUDA runs out of memory
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cuda",
+                dtypes=[torch.cdouble],
+            ),
+            # This test takes almost 2 hours to run!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestBwdGradients",
+                "test_fn_gradgrad",
+                device_type="cuda",
+                dtypes=[torch.cdouble],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        variant_test_name="hermitian",
+        dtypes=floating_and_complex_types(),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cuda",
+            ),
+            # This test is flaky under slow gradcheck, likely due to rounding issues
+            DecorateInfo(
+                skipIfSlowGradcheckEnv,
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.svd",
+        op=torch.linalg.svd,
+        aten_name="linalg_svd",
+        decomp_aten_name="_linalg_svd",
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        # We're using at::allclose, which does not have a batching rule
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_svd,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.svdvals",
+        op=torch.linalg.svdvals,
+        aten_name="linalg_svdvals",
+        decomp_aten_name="_linalg_svd",
+        dtypes=floating_and_complex_types(),
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        # We're using at::allclose, which does not have a batching rule
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_linalg_svdvals,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.tensorinv",
+        ref=np.linalg.tensorinv,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_tensorinv,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.tensorsolve",
+        ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims),
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_tensorsolve,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=2e-04, rtol=3e-06)}),
+                "TestConsistency",
+                "test_output_match",
+                device_type="mps",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+]
+
+# Python-reference (`_refs.linalg.*`) test entries for the linalg ops.
+# Each entry names, through `torch_opinfo_name`, the `op_db` entry it is paired with
+# and is handed `op_db` itself (presumably so that paired entry can be looked up).
+python_ref_db: list[OpInfo] = [
+    #
+    # torch.linalg
+    #
+    PythonRefInfo(
+        "_refs.linalg.cross",
+        torch_opinfo_name="linalg.cross",
+        supports_out=True,
+        op_db=op_db,
+        skips=(
+            # TODO: is this really needed?
+            DecorateInfo(
+                unittest.expectedFailure, "TestCommon", "test_python_ref_errors"
+            ),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.linalg.diagonal",
+        torch_opinfo_name="linalg.diagonal",
+        supports_out=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.vecdot",
+        torch_opinfo_name="linalg.vecdot",
+        op_db=op_db,
+    ),
+    # vector_norm is the only entry registered through the reduction-specific info class
+    ReductionPythonRefInfo(
+        "_refs.linalg.vector_norm",
+        torch_opinfo_name="linalg.vector_norm",
+        supports_out=True,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.matrix_norm",
+        torch_opinfo_name="linalg.matrix_norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.norm",
+        torch_opinfo_name="linalg.norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.svd",
+        torch_opinfo_name="linalg.svd",
+        supports_out=True,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.svdvals",
+        torch_opinfo_name="linalg.svdvals",
+        supports_out=True,
+        op_db=op_db,
+    ),
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f58ad2d7fb890346622a68f7b743f06f4c0f894
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py
@@ -0,0 +1,1594 @@
+# mypy: ignore-errors
+
+import math
+from copy import copy
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import torch
+from torch.fx.experimental.symbolic_shapes import is_nested_int
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    ReductionOpInfo,
+    SampleInput,
+    UnaryUfuncInfo,
+)
+from torch.utils._pytree import tree_flatten, tree_map
+
+
+@dataclass
+class ExtraOpData:
+    """
+    Contains info on top of the typical OpInfo data that is useful for NJT test generation.
+
+    The process that converts the standard op_db -> an NJT-compatible op_db will attach this
+    data onto each associated OpInfo entry.
+    """
+
+    # Indicates whether the associated op is a view op
+    is_view: bool = False
+
+    # Specifies the names of any dim-related args that the op takes in. This is useful
+    # for NJT tests because there is often asymmetry across the supported set of dims for
+    # an op; it may make sense to operate over the batch dim but not the ragged dim, for
+    # example. The length of this list should match the number of relevant overloads.
+    # Each list item of the outer list should specify dim argnames. Ellipses should be used
+    # to indicate multi-dim support for a given overload.
+    #
+    # For example, squeeze() has both a dim and multi-dim overload, where the argname for
+    # each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
+    #
+    # If no overload of the op accepts dim-related args, this should be None.
+    # (Annotated Optional because None is both the default and a documented value.)
+    dim_args: Optional[list[list[str]]] = None
+
+    def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]:
+        """
+        Extract the names of the dim-related args from ``dim_args``.
+
+        Only overloads that declare exactly one dim-related arg are considered. An
+        argname ending in "..." is treated as a dim-list arg; it also serves as the
+        single-dim argname when no single-dim argname has been seen before it.
+
+        Returns:
+            tuple of (single dim argname if available, dim list argname if available).
+            If the op doesn't support dim-related args at all OR this op only has
+            overloads with multiple dim args (e.g. transpose()), this is (None, None).
+        """
+        if self.dim_args is None:
+            return (None, None)
+
+        # name for the dim arg that supports a single dim
+        single_dim_argname = None
+        # name for the dim arg that supports a list of dims
+        dimlist_argname = None
+        for overload in self.dim_args:
+            # only consider overloads with a single dim-related arg
+            if len(overload) != 1:
+                continue
+            argname = overload[0]
+            if argname.endswith("..."):
+                dimlist_argname = argname.replace("...", "")
+                # a dim-list arg can also take one dim; use it unless a dedicated
+                # single-dim argname was already found
+                if single_dim_argname is None:
+                    single_dim_argname = dimlist_argname
+            else:
+                single_dim_argname = argname
+        return (single_dim_argname, dimlist_argname)
+
+
+# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use
+# in test generation.
+#
+# Each value is an ExtraOpData. `is_view=True` marks the op as a view op, and
+# `dim_args` lists, per overload, the names of its dim-related args; a trailing
+# "..." on a name marks an arg that accepts multiple dims.
+extra_op_data = {
+    "_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]),
+    "_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]),
+    "all": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "argmax": ExtraOpData(dim_args=[["dim"]]),
+    "argmin": ExtraOpData(dim_args=[["dim"]]),
+    "amax": ExtraOpData(dim_args=[["dim..."]]),
+    "amin": ExtraOpData(dim_args=[["dim..."]]),
+    "any": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "argsort": ExtraOpData(dim_args=[["dim"]]),
+    "broadcast_to": ExtraOpData(is_view=True),
+    "cat": ExtraOpData(dim_args=[["dim"]]),
+    "chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "conj": ExtraOpData(is_view=True),
+    "contiguous": ExtraOpData(is_view=True),
+    "count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "cummax": ExtraOpData(dim_args=[["dim"]]),
+    "cummin": ExtraOpData(dim_args=[["dim"]]),
+    "cumprod": ExtraOpData(dim_args=[["dim"]]),
+    "cumsum": ExtraOpData(dim_args=[["dim"]]),
+    "cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]),
+    "diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
+    "diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diff": ExtraOpData(dim_args=[["dim"]]),
+    "expand": ExtraOpData(is_view=True),
+    "expand_as": ExtraOpData(is_view=True),
+    "fft.fft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.hfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.ifft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.ihfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.irfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.rfft": ExtraOpData(dim_args=[["dim"]]),
+    "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
+    "flip": ExtraOpData(dim_args=[["dims..."]]),
+    "gather": ExtraOpData(dim_args=[["dim"]]),
+    "hash_tensor": ExtraOpData(dim_args=[["dim..."]]),
+    "imag": ExtraOpData(is_view=True),
+    "index_add": ExtraOpData(dim_args=[["dim"]]),
+    "index_copy": ExtraOpData(dim_args=[["dim"]]),
+    "index_fill": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
+    "index_select": ExtraOpData(dim_args=[["dim"]]),
+    "kthvalue": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.cross": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
+    "linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]),
+    "linalg.vecdot": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]),
+    "log_softmax": ExtraOpData(dim_args=[["dim"]]),
+    "logcumsumexp": ExtraOpData(dim_args=[["dim"]]),
+    "masked.amax": ExtraOpData(dim_args=[["dim"]]),
+    "masked.amin": ExtraOpData(dim_args=[["dim"]]),
+    "masked.argmax": ExtraOpData(dim_args=[["dim"]]),
+    "masked.argmin": ExtraOpData(dim_args=[["dim"]]),
+    "masked.logsumexp": ExtraOpData(dim_args=[["dim"]]),
+    "masked.mean": ExtraOpData(dim_args=[["dim"]]),
+    "masked.norm": ExtraOpData(dim_args=[["dim"]]),
+    "masked.prod": ExtraOpData(dim_args=[["dim"]]),
+    "masked.std": ExtraOpData(dim_args=[["dim"]]),
+    "masked.sum": ExtraOpData(dim_args=[["dim"]]),
+    "masked.var": ExtraOpData(dim_args=[["dim"]]),
+    "max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
+    "median": ExtraOpData(dim_args=[["dim"]]),
+    "mean": ExtraOpData(dim_args=[["dim..."]]),
+    "min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
+    "mode": ExtraOpData(dim_args=[["dim"]]),
+    "movedim": ExtraOpData(
+        dim_args=[["source", "destination"], ["source...", "destination..."]]
+    ),
+    "nanmean": ExtraOpData(dim_args=[["dim..."]]),
+    "nanmedian": ExtraOpData(dim_args=[["dim"]]),
+    "nansum": ExtraOpData(dim_args=[["dim..."]]),
+    "narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "narrow_copy": ExtraOpData(dim_args=[["dim"]]),
+    "nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]),
+    "nn.functional.glu": ExtraOpData(dim_args=[["dim"]]),
+    "permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]),
+    "positive": ExtraOpData(is_view=True),
+    "prod": ExtraOpData(dim_args=[["dim"]]),
+    "ravel": ExtraOpData(is_view=True),
+    "real": ExtraOpData(is_view=True),
+    "renorm": ExtraOpData(dim_args=[["dim"]]),
+    "reshape": ExtraOpData(is_view=True),
+    "reshape_as": ExtraOpData(is_view=True),
+    "roll": ExtraOpData(dim_args=[["dims..."]]),
+    "rot90": ExtraOpData(dim_args=[["dims..."]]),
+    "scatter": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_add": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]),
+    "select": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "select_scatter": ExtraOpData(dim_args=[["dim"]]),
+    "slice": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "slice_scatter": ExtraOpData(dim_args=[["dim"]]),
+    "softmax": ExtraOpData(dim_args=[["dim"]]),
+    "sort": ExtraOpData(dim_args=[["dim"]]),
+    "split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]),
+    "squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]),
+    "squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "stack": ExtraOpData(dim_args=[["dim"]]),
+    "std": ExtraOpData(dim_args=[["dim..."]]),
+    "std.unbiased": ExtraOpData(dim_args=[["dim..."]]),
+    "sum": ExtraOpData(dim_args=[["dim..."]]),
+    "t": ExtraOpData(is_view=True),
+    "tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "tensordot": ExtraOpData(dim_args=[["dims..."]]),
+    "tile": ExtraOpData(dim_args=[["dims..."]]),
+    "topk": ExtraOpData(dim_args=[["dim"]]),
+    "transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]),
+    "transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]),
+    "trapezoid": ExtraOpData(dim_args=[["dim"]]),
+    "trapz": ExtraOpData(dim_args=[["dim"]]),
+    "unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]),
+    "unfold_copy": ExtraOpData(dim_args=[["dimension"]]),
+    "unsafe_chunk": ExtraOpData(dim_args=[["dim"]]),
+    "unsafe_split": ExtraOpData(dim_args=[["dim"]]),
+    "unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]),
+    "var": ExtraOpData(dim_args=[["dim..."]]),
+    "var.unbiased": ExtraOpData(dim_args=[["dim..."]]),
+    "view": ExtraOpData(is_view=True),
+    "view_as": ExtraOpData(is_view=True),
+    "view_as_complex": ExtraOpData(is_view=True),
+    "view_as_real": ExtraOpData(is_view=True),
+}
+
+
+# Draws one random size: a Python int sampled with torch.randint() from 3 up to
+# (but not including) 8.
+def _rnd():
+    low, high = 3, 8
+    drawn = torch.randint(low, high, size=())
+    return drawn.item()
+
+
+# True when both arguments are nested tensors that share the same ragged dim index
+# and report the same size along that ragged dim.
+def _raggedness_matches(nt1, nt2):
+    # anything that is not nested can never match
+    if not nt1.is_nested:
+        return False
+    if not nt2.is_nested:
+        return False
+    if nt1._ragged_idx != nt2._ragged_idx:
+        return False
+    return nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx]
+
+
+# Returns an independent copy of a tensor / NJT: detached from the source's autograd
+# graph, backed by cloned data, and carrying the same requires_grad flag. SampleInputs
+# must not share the exact same tensor, as that causes autograd problems.
+def _clone(t):
+    duplicate = t.detach().clone()
+    duplicate.requires_grad_(t.requires_grad)
+    return duplicate
+
+
+# Derives a new SampleInput from an existing one: the input is cloned, args are
+# reused as-is, new_kwargs are layered over the sample's kwargs, and each new
+# "key=value" pair is appended to the sample's name.
+def _update_sample(sample, new_kwargs):
+    merged_kwargs = {**sample.kwargs, **new_kwargs}
+    name_parts = [sample.name]
+    name_parts.extend(f"{k}={v}" for (k, v) in new_kwargs.items())
+    return SampleInput(
+        _clone(sample.input),
+        args=sample.args,
+        kwargs=merged_kwargs,
+        name=", ".join(name_parts),
+    )
+
+
+# Builds a random nested tensor from a shape spec.
+# dims looks like [5, None, 10]: dims[0] is the number of components, and every None
+# after it is replaced with a freshly drawn random size for each component, which
+# is what makes the structure ragged.
+def random_nt_from_dims(
+    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
+):
+    # all component shapes are drawn first, before any component data is generated
+    component_shapes = [
+        [_rnd() if extent is None else extent for extent in dims[1:]]
+        for _ in range(dims[0])
+    ]
+    components = [torch.randn(*shape) for shape in component_shapes]
+    return torch.nested.nested_tensor(
+        components,
+        device=device,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+    )
+
+
+# Builds a short label for an NJT, for use in SampleInput names. The label has the
+# form "<ndim>D<layout tag><cache tag>", e.g. "3D_contig_with_seqlen_cache".
+def _describe_njt(njt) -> str:
+    layout_tag = "_contig" if njt.is_contiguous() else "_noncontig"
+    has_lengths_and_offsets = njt._lengths is not None and njt._offsets is not None
+    if has_lengths_and_offsets:
+        layout_tag += "_holes"
+    elif njt._ragged_idx != 1:
+        layout_tag += "_transposed"
+
+    if njt._max_seqlen_tensor is not None:
+        cache_tag = "_with_seqlen_cache"
+    else:
+        cache_tag = "_without_seqlen_cache"
+
+    return f"{njt.dim()}D{layout_tag}{cache_tag}"
+
+
+# Labels a dim relative to an NJT: "batch_dim" for dim 0, "ragged_dim" when it
+# equals the NJT's _ragged_idx, and "normal_dim" for anything else.
+def _describe_dim(njt, dim):
+    if dim == 0:
+        return "batch_dim"
+    return "ragged_dim" if dim == njt._ragged_idx else "normal_dim"
+
+
+# Yields a comprehensive set of NJT sample inputs.
+# dims may be a single int or a list / tuple of ints and defaults to [2, 3, 4]. For
+# each requested dimensionality, up to four NJTs with the same underlying random
+# data are produced, in this order:
+#   1. contiguous, with min / max seqlen cached
+#   2. contiguous, without min / max seqlen cached
+#   3. non-contiguous transposed (skipped for 2D)
+#   4. non-contiguous with holes
+def _sample_njts(device, dtype, requires_grad=False, dims=None):
+    if dims is None:
+        dims = [2, 3, 4]
+    elif not isinstance(dims, (list, tuple)):
+        dims = [dims]
+
+    for ndim in dims:
+        # random batch size first, then one random size per trailing dense dim
+        batch_size = _rnd()
+        trailing_sizes = [_rnd() for _ in range(ndim - 2)]
+
+        # contiguous NJT with min / max seqlen cached
+        base = random_nt_from_dims(
+            (batch_size, None, *trailing_sizes),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            layout=torch.jagged,
+        )
+        yield base
+
+        # contiguous NJT without min / max seqlen cached, rebuilt from buffer copies
+        rebuilt = torch.nested.nested_tensor_from_jagged(
+            _clone(base.values()), _clone(base.offsets())
+        )
+        yield rebuilt.requires_grad_(requires_grad)
+
+        # non-contiguous transposed NJT (not possible for 2D)
+        if ndim > 2:
+            yield base.transpose(-1, base._ragged_idx)
+
+        # non-contiguous NJT with holes: each length is one shorter than the gap
+        # between consecutive offsets
+        hole_values = _clone(base.values())
+        hole_offsets = _clone(base.offsets())
+        hole_lengths = _clone(hole_offsets.diff() - 1)
+        with_holes = torch.nested.nested_tensor_from_jagged(
+            values=hole_values,
+            offsets=hole_offsets,
+            lengths=hole_lengths,
+        )
+        yield with_holes.requires_grad_(requires_grad)
+
+
+# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
+#
+# The first NJT found (sample.input if it is nested, otherwise the first nested
+# tensor in sample.args) defines the expected ragged structure. For each index i
+# along its batch dim, input / args / kwargs are sliced down to component i, op.op is
+# invoked on the slices, and the per-component results are collected.
+#
+# Args:
+#   op: OpInfo-like object; op.op, op.full_name and op._extra_op_data are read here
+#   sample: SampleInput providing input, args and kwargs
+#   wrap_output_as_njt: if True, the collected results are packed into jagged NJT(s)
+#       (one per return value when the op returns a list / tuple); if False, the
+#       plain Python list of per-component results is returned
+#
+# Raises StopIteration if neither sample.input nor any tensor in sample.args is nested.
+def unbind_reference(op, sample, wrap_output_as_njt=True):
+    # loop-invariant, so imported once up front rather than on every iteration
+    from torch.nested._internal.ops import _outer_to_inner_dim
+
+    # first NJT in the arglist determines expected ragged structure
+    nt_inp = (
+        sample.input
+        if sample.input.is_nested
+        # TODO: look in kwargs too?
+        # non-tensor args (e.g. Python scalars) have no is_nested; skip them
+        else next(
+            a for a in sample.args if isinstance(a, torch.Tensor) and a.is_nested
+        )
+    )
+
+    out_ref_components = []
+    for i in range(nt_inp.shape[0]):
+
+        # i / inp are bound as defaults so the closure keeps this iteration's values
+        def _slice_input(t, i=i, inp=nt_inp):
+            # any NJT with the same ragged structure as the input should
+            # be sliced to pass to the reference
+            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
+                return t[i]
+            # allow the SampleInput to tell us how to slice it for ref calculation
+            elif isinstance(t, torch.Tensor) and hasattr(t, "_batch_dim"):
+                bdim = t._batch_dim  # type: ignore[attr]
+                # a size-1 batch dim broadcasts, so every component uses index 0;
+                # select along bdim itself (t[0] would only be right for bdim == 0)
+                return t.select(bdim, 0 if t.shape[bdim] == 1 else i)
+            else:
+                return t
+
+        inp = _slice_input(sample.input)
+        args = tree_map(_slice_input, sample.args)
+        kwargs = tree_map(_slice_input, sample.kwargs)
+
+        # Handle indices in index_put
+        if "index_put" in op.full_name and "indices" in kwargs:
+            if len(kwargs["indices"]) > 1:
+                # If after unrolling we still have indices left, use them
+                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
+            else:
+                # If no indices are left, create them so they match the NJT implementation
+                sequence_put = kwargs["indices"][0].tolist()
+                if i in sequence_put:
+                    # component i is targeted: index every row of the component
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            list(range(inp.shape[0])),
+                            dtype=torch.int32,
+                            device=kwargs["indices"][0].device,
+                        )
+                    ]
+                else:
+                    # component i is not targeted: pass an empty index tensor
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            [], dtype=torch.int32, device=kwargs["indices"][0].device
+                        )
+                    ]
+
+        # Need to adjust dims to apply on NJT component
+        if op._extra_op_data.dim_args is not None:
+            # get all possible dim-related argnames that could be encountered for this op
+            argnames = tree_map(
+                lambda a: a.replace("...", ""),
+                tree_flatten(op._extra_op_data.dim_args)[0],
+            )
+            # for all dim-related args present, convert from outer -> inner dim space
+            for argname in {a for a in argnames if a in kwargs}:
+                # allow the SampleInput to tell us how to canonicalize the dim kwargs
+                ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim()
+                kwargs[argname] = _outer_to_inner_dim(
+                    ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True
+                )
+
+        out_ref_component = op.op(inp, *args, **kwargs)
+        out_ref_components.append(out_ref_component)
+
+    if wrap_output_as_njt:
+        # handle list / tuple of outputs
+        if len(out_ref_components) > 0 and isinstance(
+            out_ref_components[0], (list, tuple)
+        ):
+            num_returns = len(out_ref_components[0])
+            # ensure we get the same number of returns for each invocation
+            assert all(len(o) == num_returns for o in out_ref_components)
+            # construct NJTs from same index returns from each invocation
+            njt_returns = [
+                torch.nested.as_nested_tensor(
+                    [o[r] for o in out_ref_components], layout=torch.jagged
+                )
+                for r in range(num_returns)
+            ]
+            return type(out_ref_components[0])(njt_returns)
+        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)
+
+    return out_ref_components
+
+
+# Reference for a non-reduction unary op that is applied along a single dim.
+# The op must expose exactly one non-list dim arg. When that arg is 0 (the batch
+# dim) the caller-supplied batchwise_reference is used, since an unbind-based
+# reference cannot express batch-wise operation; every other dim goes through
+# unbind_reference().
+def unary_dimwise_reference(op, sample, batchwise_reference=None):
+    extra = op._extra_op_data
+    assert extra.dim_args is not None
+    single_dim_argname, dimlist_argname = extra.get_dim_argnames()
+    # only support a single non-list dim arg for now
+    assert dimlist_argname is None
+    assert single_dim_argname is not None
+
+    operates_on_batch = sample.kwargs[single_dim_argname] == 0
+    if not operates_on_batch:
+        return unbind_reference(op, sample)
+
+    assert batchwise_reference is not None
+    return batchwise_reference(op, sample)
+
+
+# Computes the reference value for a reduction op applied to an NJT input.
+# The dim(s) being reduced are read from the sample kwargs (the dim-list argname is
+# looked up first, then the single-dim argname) and select the strategy:
+#   * no dim given              -> run the op on the values buffer
+#   * ragged dim not reduced    -> unbind reference, wrapped as an NJT
+#   * ragged and batch reduced  -> run the op on the values buffer with the dims
+#                                  converted to inner space
+#   * ragged reduced, not batch -> unbind reference, results stacked along dim 0
+# Reducing over just the batch dim (dim == 0) is rejected by an assert.
+def reduction_reference(op, sample):
+    njt = sample.input
+    assert njt.is_nested
+
+    # extract info about the dim args this op supports
+    extra = op._extra_op_data
+    assert extra.dim_args is not None
+    single_dim_argname, dimlist_argname = extra.get_dim_argnames()
+    assert single_dim_argname is not None
+
+    single_dim_value = sample.kwargs.get(single_dim_argname, None)
+    dim = sample.kwargs.get(dimlist_argname, single_dim_value)
+    keepdim = sample.kwargs.get("keepdim", False)
+    assert dim != 0, "reductions over just the batch dim are not supported"
+
+    ragged_idx = njt._ragged_idx
+    if isinstance(dim, (tuple, list)):
+        reduce_on_ragged, reduce_on_batch = ragged_idx in dim, 0 in dim
+    else:
+        reduce_on_ragged, reduce_on_batch = ragged_idx == dim, dim == 0
+
+    if dim is None:
+        # calculate reference value by running reduction on values buffer
+        return op.op(njt.values(), *sample.args, **sample.kwargs)
+
+    if not reduce_on_ragged:
+        # unbind reference works for reductions that leave the ragged dim alone
+        return unbind_reference(op, sample)
+
+    if reduce_on_batch:
+        # run reference directly on buffer with dims converted to inner space
+        from torch.nested._internal.ops import _outer_to_inner_dim
+
+        ref_kwargs = dict(sample.kwargs)
+        assert dimlist_argname is not None
+        ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
+            njt.dim(), dim, ragged_idx, canonicalize=True
+        )
+        out = op.op(njt.values(), *sample.args, **ref_kwargs)
+        if not keepdim:
+            return out
+        if isinstance(out, (tuple, list)):
+            # some ops return multiple things; unsqueeze all of them
+            return type(out)(o.unsqueeze(0) for o in out)
+        return out.unsqueeze(0)
+
+    # ragged dim reduced but batch dim kept: run an unbind reference and stack
+    components = unbind_reference(op, sample, wrap_output_as_njt=False)
+    multiple_returns = len(components) > 0 and isinstance(
+        components[0], (tuple, list)
+    )
+    if not multiple_returns:
+        return torch.stack(components, dim=0)
+
+    # some ops return multiple things; stack all of them
+    num_returns = len(components[0])
+    # ensure we get the same number of returns for each invocation
+    assert all(len(o) == num_returns for o in components)
+    # stack same index returns from each invocation
+    stacked_returns = [
+        torch.stack([o[r] for o in components], dim=0) for r in range(num_returns)
+    ]
+    return type(components[0])(stacked_returns)
+
+
+# Yields one SampleInput per NJT produced by _sample_njts() for dims 2, 3 and 4, for
+# use with elementwise unary ops. Each sample gets its own copy of op_kwargs and is
+# named after its NJT via _describe_njt(). op_info and **kwargs are not used.
+def sample_inputs_elementwise_njt_unary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    shared_kwargs = op_kwargs or {}
+    njts = _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    )
+    for njt in njts:
+        yield SampleInput(njt, kwargs=dict(shared_kwargs), name=_describe_njt(njt))
+
+
+# Generates SampleInputs for elementwise binary ops involving NJTs.
+#
+# For every NJT from _sample_njts() (dims 2, 3 and 4) this yields an (NT, NT) pair
+# followed by (NT, T) and (T, NT) pairs against dense tensors of several
+# broadcastable shapes. One extra "mixed broadcasting" pair of samples is yielded
+# at the end. Each SampleInput receives its own copy of op_kwargs, and sample names
+# are prefixed with the _describe_njt() label of the NJT involved.
+#
+# Dense operands that need per-batch-item slicing are tagged with a `_batch_dim`
+# attribute, which unbind_reference() reads.
+# op_info and **kwargs are not used in the body.
+def sample_inputs_elementwise_njt_binary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    for njt1 in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        njt_desc = _describe_njt(njt1)
+        # second operand mirrors njt1 via randn_like; requires_grad is not set on it here
+        njt2 = torch.randn_like(njt1)
+        yield SampleInput(
+            _clone(njt1),
+            args=(njt2,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, NT)",
+        )
+
+        # broadcasting case: (B, j0, ...) with (B, 1, ...)
+        dense_shape = list(njt1.shape)
+        dense_shape[njt1._ragged_idx] = 1
+        t = torch.randn(
+            dense_shape,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        # separate copy so the (NT, T) and (T, NT) samples do not share a tensor
+        t2 = _clone(t)
+        # used for slicing in unbind_reference()
+        t._batch_dim = 0
+        t2._batch_dim = 0
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(t,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting 1 over ragged",
+        )
+        # (T, NT)
+        yield SampleInput(
+            t2,
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting 1 over ragged",
+        )
+
+        # broadcasting case: (B, j0, ...) with (1, 1...)
+        t = torch.randn(
+            [1 for _ in range(njt1.dim())],
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        t2 = _clone(t)
+        # used for slicing in unbind_reference()
+        t._batch_dim = 0
+        t2._batch_dim = 0
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(t,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting all 1s",
+        )
+        # (T, NT)
+        yield SampleInput(
+            t2,
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting all 1s",
+        )
+
+        # broadcasting case: (B, j0, ...) with (...)
+        # only possible when there are dims after the ragged dim
+        if njt1.dim() > njt1._ragged_idx + 1:
+            t = torch.randn(
+                njt1.shape[njt1._ragged_idx + 1 :],
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+            # no _batch_dim tag here: this dense operand has no batch dim to slice
+            # (NT, T)
+            yield SampleInput(
+                _clone(njt1),
+                args=(_clone(t),),
+                kwargs=dict(op_kwargs),
+                name=f"{njt_desc}: (NT, T) broadcasting normal dims",
+            )
+            # (T, NT)
+            yield SampleInput(
+                _clone(t),
+                args=(_clone(njt1),),
+                kwargs=dict(op_kwargs),
+                name=f"{njt_desc}: (T, NT) broadcasting normal dims",
+            )
+
+        # broadcasting case: (B, j0, ...) with scalar
+        t = torch.randn((), device=device, dtype=dtype, requires_grad=requires_grad)
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(_clone(t),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting with scalar",
+        )
+        # (T, NT)
+        yield SampleInput(
+            _clone(t),
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting with scalar",
+        )
+
+    # mixed broadcasting case: (B, j0, 1) with (B, 1, D)
+    # fixed sizes are used for this final case rather than _rnd()
+    B = 4
+    D = 16
+    njt = random_nt_from_dims(
+        (B, None, 1),
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        layout=torch.jagged,
+    )
+    njt_desc = _describe_njt(njt)
+    t = torch.randn(B, 1, D, device=device, dtype=dtype, requires_grad=requires_grad)
+    t2 = _clone(t)
+    # used for slicing in unbind_reference()
+    t._batch_dim = 0
+    t2._batch_dim = 0
+
+    # (NT, T)
+    yield SampleInput(
+        _clone(njt),
+        args=(t,),
+        kwargs=dict(op_kwargs),
+        name=f"{njt_desc}: (NT, T) mixed broadcasting",
+    )
+    # (T, NT)
+    yield SampleInput(
+        t2,
+        args=(_clone(njt),),
+        kwargs=dict(op_kwargs),
+        name=f"{njt_desc}: (T, NT) mixed broadcasting",
+    )
+
+
+def sample_inputs_njt_reduction(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    supports_keepdim=True,
+    op_kwargs=None,
+    **kwargs,
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    # extract info about the dim args this op supports
+    assert op_info._extra_op_data.dim_args is not None
+    (
+        single_dim_argname,
+        dimlist_argname,
+    ) = op_info._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+    supports_dimlist = dimlist_argname is not None
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        njt_desc = _describe_njt(njt)
+        keepdim_values = [False, True] if supports_keepdim else [None]
+        for keepdim in keepdim_values:
+            keepdim_suffix = f" with keepdim={keepdim}" if supports_keepdim else ""
+            # single dim-wise reduction; includes reduction over the ragged dim
+            # NB: reduction over the batch dim is not supported!
+            # TODO: Cover this in the set of error inputs
+            for dim in range(1, njt.dim()):
+                dim_desc = "normal" if dim != njt._ragged_idx else "ragged"
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        single_dim_argname: dim,
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: {dim_desc} dim reduction{keepdim_suffix}",
+                )
+
+            if supports_dimlist:
+                # reduce on both batch and ragged dims
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        dimlist_argname: [0, njt._ragged_idx],
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: batch+ragged reduction{keepdim_suffix}",
+                )
+
+                # reduce on batch, ragged, and other dims
+                for other_dim in range(njt._ragged_idx + 1, njt.dim()):
+                    yield SampleInput(
+                        _clone(njt),
+                        kwargs={
+                            **op_kwargs,
+                            dimlist_argname: [0, njt._ragged_idx, other_dim],
+                            **({"keepdim": keepdim} if supports_keepdim else {}),
+                        },
+                        name=(
+                            f"{njt_desc}: batch+ragged+dim={other_dim} "
+                            f"reduction{keepdim_suffix}"
+                        ),
+                    )
+
+                # reduce on two non-ragged, non-batch dims
+                if njt.dim() > 3 and njt._ragged_idx == 1:
+                    yield SampleInput(
+                        _clone(njt),
+                        kwargs={
+                            **op_kwargs,
+                            dimlist_argname: [njt.dim() - 2, njt.dim() - 1],
+                            **({"keepdim": keepdim} if supports_keepdim else {}),
+                        },
+                        name=f"{njt_desc}: two normal dim reduction{keepdim_suffix}",
+                    )
+
+                # full reduction by specifying all dims
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        dimlist_argname: list(range(njt.dim())),
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: all dim reduction{keepdim_suffix}",
+                )
+
+                # TODO: Reducing on ragged dim and non-batch dim is not supported;
+                # cover this in the set of error inputs.
+
+        # full reduction
+        yield SampleInput(
+            _clone(njt),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: full reduction with keepdim={keepdim}",
+        )
+
+
+def unsupported_sample_inputs_func(op_name):
+    def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs):
+        raise RuntimeError(
+            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
+            "torch/testing/_internal/opinfo/definitions/nested.py."
+        )
+
+    return _f
+
+
+def unsupported_reference(op_name):
+    def _f(op, sample):
+        raise RuntimeError(
+            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
+            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
+        )
+
+    return _f
+
+
+# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
+def sample_inputs_unary_dimwise(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if op_kwargs is None:
+        op_kwargs = {}
+
+    # only support a single non-list dim arg for now
+    assert op_info._extra_op_data is not None
+    single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+    assert dimlist_argname is None
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            kwargs = {single_dim_argname: dim}
+            kwargs.update(op_kwargs)
+            yield SampleInput(
+                _clone(njt),
+                kwargs=kwargs,
+                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
+            )
+
+
+def batchwise_reference_chunk(op, sample):
+    # reference for chunk() over dim=0
+    B = sample.input.size(0)
+    num_chunks = sample.kwargs["chunks"]
+    chunk_size = math.ceil(B / num_chunks)
+    num_full_chunks = B // chunk_size
+    chunk_sizes = [chunk_size for _ in range(num_full_chunks)]
+    if B % chunk_size != 0:
+        # final chunk contains the leftovers
+        chunk_sizes.append(B % chunk_size)
+
+    # split unbound components into chunks according to calculated sizes
+    components = list(sample.input.unbind())
+    start = 0
+    chunks = []
+    for chunk_size in chunk_sizes:
+        chunks.append(components[start : start + chunk_size])
+        start += chunk_size
+
+    # rejoin into NJT outputs
+    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]
+
+
+def batchwise_reference_narrow(op, sample):
+    """Placeholder for the batch-dim reference of narrow(); always raises."""
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_select(op, sample):
+    # reference for select() over dim=0
+    return sample.input.unbind()[sample.kwargs["index"]]
+
+
+def batchwise_reference_split(op, sample):
+    """Placeholder for the batch-dim reference of split(); always raises."""
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_split_with_sizes(op, sample):
+    """Placeholder for the batch-dim reference of split_with_sizes(); always raises."""
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_unflatten(op, sample):
+    """Placeholder for the batch-dim reference of unflatten(); always raises."""
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_unsqueeze(op, sample):
+    """Batch-dim reference for unsqueeze(); always raises ValueError.
+
+    Operating on the batch dim is treated as a usage error rather than as an
+    unimplemented reference.
+    """
+    raise ValueError("unsqueeze() is not intended to operate on the batch dim")
+
+
+def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
+    # non-contiguous NJTs
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        yield SampleInput(njt, name=_describe_njt(njt))
+
+    for memory_format in (torch.contiguous_format, torch.preserve_format):
+        # construct a "non-contiguous with holes" NJT
+        values = torch.randn(
+            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
+        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
+        njt = torch.nested.nested_tensor_from_jagged(
+            values, offsets=offsets, lengths=lengths
+        )
+
+        njt_desc = _describe_njt(njt)
+        yield SampleInput(
+            njt,
+            kwargs={"memory_format": memory_format},
+            name=f"{njt_desc}: {memory_format})",
+        )
+
+
+def sample_inputs_fill(op_info, device, dtype, requires_grad, **kwargs):
+    # scalar case
+    unary_func = partial(sample_inputs_elementwise_njt_unary, op_kwargs={"value": 42.0})
+    yield from unary_func(op_info, device, dtype, requires_grad)
+
+    # TODO: add Tensor case
+
+
+def sample_inputs_mvl_gamma(p):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p})
+
+
+def sample_inputs_polygamma_n(n):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})
+
+
+def sample_inputs_special_polygamma_n(n):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})
+
+
+def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
+    for njt in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[2, 3, 4],
+    ):
+        other_dtypes = (
+            d for d in (torch.float32, torch.half, torch.double) if d is not dtype
+        )
+        for other_dtype in other_dtypes:
+            sample_name = f"{njt.dim()}D: {dtype} -> {other_dtype}"
+            yield SampleInput(_clone(njt), kwargs={"dtype": dtype}, name=sample_name)
+
+        # only include device transfer for CUDA inputs
+        if "cuda" in device:
+            other_device = "cpu"
+            sample_name = f"{_describe_njt(njt)}: {device} -> {other_device}"
+            yield SampleInput(
+                _clone(njt), kwargs={"device": other_device}, name=sample_name
+            )
+
+
+def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
+    for njt_3d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
+    ):
+        # (B, j1, D) x (B, D, E) => (B, j1, E)
+        if njt_3d._ragged_idx == 1:
+            B, D = njt_3d.shape[0], njt_3d.shape[-1]
+            E = D + 2
+            other = torch.randn(B, D, E, device=device, dtype=dtype)
+            # used for slicing in unbind_reference()
+            other._batch_dim = 0
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"mat2": other},
+                name=f"{njt_desc}: (B, j, D) x (B, D, E)",
+            )
+
+        # TODO (need factory functions):
+        # (B, D, j1) x (B, j1, E) => (B, D, E)
+
+
+def reference_bmm(op, sample):
+    # unbind reduces a dim and bmm requires 3D, so use matmul as the reference
+    matmul_op = copy(op)
+    matmul_op.op = torch.matmul
+    # change arg name from mat2 -> other
+    modified_sample = copy(sample)
+    other = modified_sample.kwargs["mat2"]
+    del modified_sample.kwargs["mat2"]
+    modified_sample.kwargs["other"] = other
+    return unbind_reference(matmul_op, modified_sample)
+
+
+def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single chunks value
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"chunks": 3})
+        # other dim chunking: test different chunks values
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for chunks in [1, D // 2, D - 1, D]:
+                yield _update_sample(sample_input, {"chunks": chunks})
+
+
+def sample_inputs_matmul(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    # also run bmm samples through
+    for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad):
+        # change arg name from mat2 -> other
+        other = sample_input.kwargs["mat2"]
+        del sample_input.kwargs["mat2"]
+        sample_input.kwargs["other"] = other
+        yield sample_input
+
+    # 3D cases not covered by bmm
+    for njt_3d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
+    ):
+        # (B, j1, D) x (D, E) => (B, j1, E)
+        if njt_3d._ragged_idx == 1:
+            D = njt_3d.shape[-1]
+            E = D + 2
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)},
+                name=f"{njt_desc}: (B, j, D) x (D, E)",
+            )
+
+    # 4D cases
+    for njt_4d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[4]
+    ):
+        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
+        if njt_4d._ragged_idx == 1:
+            E = njt_4d.shape[-1]
+            F = E + 2
+            njt_desc = _describe_njt(njt_4d)
+            yield SampleInput(
+                _clone(njt_4d),
+                kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)},
+                name=f"{njt_desc}: (B, j, D, E) x (E, F)",
+            )
+
+    # Dense x NJT cases
+    for njt_3d in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[3],
+    ):
+        # (B, F, E) x (B, E, j1) => (B, F, j1)
+        if njt_3d._ragged_idx == 2:
+            B = njt_3d.shape[0]
+            E = njt_3d.shape[1]
+            F = E + 2
+            njt_desc = _describe_njt(njt_3d)
+            dense_t = torch.randn(
+                B, F, E, device=device, dtype=dtype, requires_grad=requires_grad
+            )
+            dense_t._batch_dim = 0  # for unbind_reference()
+            yield SampleInput(
+                dense_t,
+                args=(_clone(njt_3d),),
+                name=f"{njt_desc}: (B, F, E) x (B, E, j1)",
+            )
+
+    # NJT x NJT => Dense case
+    for njt_3d in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[3],
+    ):
+        # (B, E, j1) x (B, j1, F) => (B, E, F)
+        if njt_3d._ragged_idx == 2 and njt_3d.is_contiguous():
+            B, E, _ = njt_3d.shape
+            sum_j1 = len(njt_3d.values())
+            other_cont = torch.randn(
+                sum_j1, E + 2, device=device, dtype=dtype, requires_grad=requires_grad
+            )
+            other_njt = torch.nested.nested_tensor_from_jagged(
+                other_cont, njt_3d.offsets(), lengths=njt_3d._lengths
+            )
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"other": _clone(other_njt)},
+                name=f"{njt_desc}: (B, E, j1) x (B, j1, F)",
+            )
+
+        # TODO (need factory functions):
+        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F)
+
+
+def sample_inputs_masked_select(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
+    ):
+        yield SampleInput(
+            njt,
+            kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)},
+            name=_describe_njt(njt),
+        )
+
+
+def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim narrowing: test a single start, length value
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"start": 1, "length": 2})
+        # other dim narrowing: test different start, length values
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for start, length in [(0, D), (0, D - 1), (1, D - 1), (D - 1, 1)]:
+                yield _update_sample(sample_input, {"start": start, "length": length})
+
+
+def sample_inputs_nn_functional_embedding(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    indices = torch.nested.nested_tensor(
+        [
+            torch.tensor([0, 2, 1, 3]),
+            torch.tensor([4, 2, 1]),
+            torch.tensor([6, 7, 5, 2, 4]),
+        ],
+        layout=torch.jagged,
+        dtype=torch.int64,
+        device=device,
+    )
+
+    NUM_EMBEDDINGS = 20
+    EMBEDDING_DIM = 32
+    weight = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM, device=device, dtype=dtype)
+
+    # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
+    # can be checked
+    yield SampleInput(
+        _clone(weight).requires_grad_(),
+        args=(indices,),
+    )
+
+    yield SampleInput(
+        _clone(weight).requires_grad_(),
+        args=(indices,),
+        kwargs={"padding_idx": 1},
+    )
+
+
+def sample_inputs_index_put(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            indices = [
+                torch.tensor(list(range(njt.size(0))), device=njt.device),
+                *[
+                    torch.tensor([0] * njt.size(0), device=njt.device)
+                    for _ in range(dim - 1)
+                ],
+            ]
+            njt_desc = _describe_njt(njt)
+            yield SampleInput(
+                _clone(njt),
+                kwargs={
+                    "indices": indices,
+                    "values": torch.tensor(1.0, device=njt.device),
+                },
+                name=f"{njt_desc}: up to dim {dim - 1}",
+            )
+
+    # Non-cont NJT for completeness
+    offsets = torch.tensor([0, 2, 5, 7], device=device)
+    lengths = torch.tensor([2, 2, 2], device=device)
+    indices = [
+        torch.tensor([0, 1, 2], device=device),
+        torch.tensor([0, 1, 1], device=device),
+        torch.tensor([0, 0, 0], device=device),
+    ]
+    a = torch.nested.nested_tensor_from_jagged(
+        torch.zeros(7, 3, device=device), offsets, lengths
+    ).requires_grad_(requires_grad)
+
+    njt_desc = _describe_njt(a)
+    yield SampleInput(
+        _clone(a),
+        kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)},
+        name=f"{njt_desc}: all dims",
+    )
+
+
+def sample_inputs_nn_functional_embedding_bag(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    for generate_per_sample_weight in (True, False):
+        for mode in ("sum", "mean", "max"):
+            # per_sample_weights is only supported for mode='sum'
+            if mode != "sum" and generate_per_sample_weight:
+                continue
+
+            NUM_EMBEDDINGS = 10
+            EMBEDDING_DIM = 32
+            weight = torch.randn(
+                NUM_EMBEDDINGS, EMBEDDING_DIM, dtype=dtype, device=device
+            )
+
+            njt = torch.nested.nested_tensor(
+                [
+                    torch.randint(0, NUM_EMBEDDINGS, size=(2,)),
+                    torch.randint(0, NUM_EMBEDDINGS, size=(3,)),
+                    torch.randint(0, NUM_EMBEDDINGS, size=(4,)),
+                ],
+                layout=torch.jagged,
+                dtype=torch.int64,
+                device=device,
+            )
+
+            per_sample_weights = None
+            if generate_per_sample_weight:
+                per_sample_weights = torch.randn_like(njt, dtype=dtype)
+
+            # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
+            # can be checked
+            yield SampleInput(
+                weight,
+                args=(njt,),
+                kwargs={
+                    "mode": mode,
+                    "per_sample_weights": per_sample_weights,
+                },
+            )
+
+
+def reference_nn_functional_embedding_bag(op, sample):
+    # run reference on a single bag at a time
+    new_kwargs = dict(sample.kwargs)
+    new_kwargs.update(
+        {"offsets": torch.tensor([0], dtype=torch.int64, device=sample.input.device)}
+    )
+    # flip input / weight back to what unbind_reference() expects
+    sample = SampleInput(sample.args[0], args=(sample.input,), kwargs=new_kwargs)
+    old_op = op.op
+    op.op = torch.nn.functional.embedding_bag
+    output = unbind_reference(op, sample, wrap_output_as_njt=False)
+    op.op = old_op
+    # concat bag outputs to get final output
+    return torch.cat(output, dim=0)
+
+
+def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **kwargs):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4, 5]
+    ):
+        # projection over a ragged dim is not currently supported
+        if is_nested_int(njt.size(-1)):
+            continue
+
+        # with bias
+        NUM_OUTPUT = 10
+        weight = torch.randn(
+            NUM_OUTPUT,
+            njt.size(-1),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        bias = torch.randn(
+            NUM_OUTPUT, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+                "bias": _clone(bias),
+            },
+            name=f"{_describe_njt(njt)}: with bias",
+        )
+
+        # without bias
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+            },
+            name=f"{_describe_njt(njt)}: without bias",
+        )
+
+
+def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
+    ):
+        # Second dim is interpreted as number of channels; this should be non-ragged for now
+        num_channels = njt.size(1)
+        if is_nested_int(num_channels):
+            continue
+
+        # 1D weight
+        weight = torch.randn(
+            num_channels,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+            },
+            name=f"{_describe_njt(njt)}: 1D weight",
+        )
+
+        # scalar tensor weight
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": torch.tensor(4.2, device=device, dtype=dtype),
+            },
+            name=f"{_describe_njt(njt)}: scalar tensor weight",
+        )
+
+
+def sample_inputs_nn_functional_rms_norm(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
+    ):
+        # normalize over non-ragged dims
+        for start_dim in range(njt.dim()):
+            if start_dim <= njt._ragged_idx:
+                continue
+
+            normalized_shape = njt.shape[start_dim:]
+            weight = torch.randn(
+                normalized_shape,
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+
+            yield SampleInput(
+                _clone(njt),
+                kwargs={
+                    "normalized_shape": normalized_shape,
+                    "weight": weight,
+                },
+                name=f"{_describe_njt(njt)}",
+            )
+
+
+# Sample inputs for nn.functional.threshold: the elementwise-unary NJT samples
+# with fixed op kwargs. The threshold is spelled as a hex float literal
+# (0x1.3ap-3 == 0.1533203125) and the replacement value is -9.
+sample_inputs_nn_functional_threshold = partial(
+    sample_inputs_elementwise_njt_unary,
+    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
+)
+
+
+def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single index
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"index": 0})
+        # other dim chunking: test different indices
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for index in [0, D // 2, D - 1]:
+                yield _update_sample(sample_input, {"index": index})
+
+
+def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single split size
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"split_size_or_sections": 3})
+        # other dim chunking: test different split sizes
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for split_size in [1, D // 2, D - 1, D]:
+                yield _update_sample(
+                    sample_input, {"split_size_or_sections": split_size}
+                )
+
+
+def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # It will never make sense to operate on the ragged dim.
+        # TODO: Handle this with error_inputs
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            continue
+
+        D = sample_input.input.size(sample_input.kwargs["dim"])
+        # splits should add up to D
+        split1 = torch.randint(0, D - 1, size=()).item()
+        split2 = D - split1
+        yield _update_sample(sample_input, {"split_sizes": [split1, split2]})
+
+
+def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
+    # squeeze-specific NJT generator (need to ensure there are some 1s in the shape)
+    def _get_njts():
+        njt = random_nt_from_dims(
+            (4, None, 1, 3, 1),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            layout=torch.jagged,
+        )
+        yield njt
+        # without min / max seqlen cached
+        values = njt.values().detach().clone()
+        offsets = njt.offsets().detach().clone()
+        yield torch.nested.nested_tensor_from_jagged(values, offsets)
+        # non-contiguous transposed
+        yield njt.transpose(1, 3)
+        # non-contiguous with holes
+        values = njt.values().detach().clone()
+        offsets = njt.offsets().detach().clone()
+        # subtract 1 to cause holes
+        lengths = (offsets.diff() - 1).detach().clone()
+        yield torch.nested.nested_tensor_from_jagged(
+            values=values,
+            offsets=offsets,
+            lengths=lengths,
+        )
+
+    for njt in _get_njts():
+        # single dim operation
+        for dim in range(njt.dim()):
+            # Operation on batch / ragged dim is never expected to work.
+            # TODO: Handle these via error_inputs.
+            if dim == 0 or dim == njt._ragged_idx:
+                continue
+
+            yield SampleInput(
+                _clone(njt),
+                kwargs={"dim": dim},
+                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
+            )
+
+        # multiple dim operation (pass no args)
+        yield SampleInput(
+            _clone(njt),
+            kwargs={"dim": dim},
+            name=f"{_describe_njt(njt)}: multiple dims",
+        )
+
+
+def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # It will never make sense to operate on the ragged dim.
+        # TODO: Handle this with error_inputs
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            continue
+
+        D = sample_input.input.size(sample_input.kwargs["dim"])
+        # sizes should multiply to be D
+        yield _update_sample(sample_input, {"sizes": [D, 1]})
+        yield _update_sample(sample_input, {"sizes": [1, D]})
+        if D % 2 == 0:
+            yield _update_sample(sample_input, {"sizes": [D // 2, 2]})
+            yield _update_sample(sample_input, {"sizes": [2, D // 2]})
+
+
+def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        yield sample_input
+
+        last_dim_sample = _update_sample(sample_input, {"dim": -1})
+        last_dim_sample.name = (
+            f"{_describe_njt(last_dim_sample.input)}: add dim to the end"
+        )
+        # Tell the unbind reference how to canonicalize the dim kwargs
+        # This is necessary because unsqueeze() allows for a dim after
+        # the last dim to indicate an unsqueeze at the end.
+        last_dim_sample.input._ndim = last_dim_sample.input.dim() + 1
+        yield last_dim_sample
+
+
+def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in sample_inputs_elementwise_njt_binary(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        other = sample.args[0]
+        sample.args = ()
+        sample.kwargs["other"] = other
+        sample.kwargs["condition"] = sample.input > 0.0
+        sample.name = sample.name.replace("(", "(NT, ")
+        yield sample
+
+
+# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
+
+
+# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
+# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name
+# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These are necessary
+# to specify if they cannot be auto-generated for some reason. Try to keep these sorted
+# in alphabetical order!
+# Keys are OpInfo full names; values are sample_inputs_funcs (plain functions,
+# partials, or the result of a factory such as sample_inputs_polygamma_n).
+njt_sample_inputs = {
+    "bmm": sample_inputs_bmm,
+    "chunk": sample_inputs_chunk,
+    "clone": sample_inputs_clone,
+    "count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False),
+    "fill": sample_inputs_fill,
+    # NOTE(review): every mvlgamma variant (p in 1, 3, 5) is registered with
+    # p=1 rather than its own p — confirm whether this is deliberate (e.g. to
+    # stay within the op's input domain) or should be p=p.
+    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)},
+    "nn.functional.embedding": sample_inputs_nn_functional_embedding,
+    "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag,
+    "nn.functional.linear": sample_inputs_nn_functional_linear,
+    "nn.functional.prelu": sample_inputs_nn_functional_prelu,
+    "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm,
+    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
+    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
+    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
+    "to": sample_inputs_to,
+    "matmul": sample_inputs_matmul,
+    "masked_select": sample_inputs_masked_select,
+    "narrow": sample_inputs_narrow,
+    "index_put": sample_inputs_index_put,
+    # these two don't have ReductionOpInfo entries
+    "max.reduction_with_dim": sample_inputs_njt_reduction,
+    "min.reduction_with_dim": sample_inputs_njt_reduction,
+    "select": sample_inputs_select,
+    "split": sample_inputs_split,
+    "split_with_sizes": sample_inputs_split_with_sizes,
+    "squeeze": sample_inputs_squeeze,
+    "unflatten": sample_inputs_unflatten,
+    "unsqueeze": sample_inputs_unsqueeze,
+    "where": sample_inputs_where,
+}
+
+njt_references = {  # OpInfo full name -> reference implementation, looked up in translate_opinfo
+    "bmm": reference_bmm,
+    "chunk": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk
+    ),
+    "count_nonzero": reduction_reference,
+    # these two don't have ReductionOpInfo entries
+    "max.reduction_with_dim": reduction_reference,
+    "min.reduction_with_dim": reduction_reference,
+    "narrow": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow
+    ),
+    "select": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_select
+    ),
+    "split": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_split
+    ),
+    "split_with_sizes": partial(
+        unary_dimwise_reference,
+        batchwise_reference=batchwise_reference_split_with_sizes,
+    ),
+    "squeeze": unbind_reference,
+    "nn.functional.embedding_bag": reference_nn_functional_embedding_bag,
+    "unflatten": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten
+    ),
+    "unsqueeze": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze
+    ),
+}
+
+
+# Translates an OpInfo entry to one that operates on NJTs; the translated object is returned.
+def translate_opinfo(op):
+    new_op = copy(op)  # attributes below are set on new_op; `copy` is presumably copy.copy — confirm
+    new_op.supports_njt = True  # reset to False in the final else-branch below
+    # add some extra info for use in generating tests on the right subset of ops
+    new_op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData())
+
+    if op.full_name in njt_sample_inputs:  # explicit per-op entry wins over the class-based defaults
+        new_op.sample_inputs_func = njt_sample_inputs[op.full_name]
+        new_op.ref = njt_references.get(op.full_name, unbind_reference)  # unbind_reference if unlisted
+    elif isinstance(op, UnaryUfuncInfo):
+        new_op.sample_inputs_func = partial(
+            sample_inputs_elementwise_njt_unary, op_kwargs=None
+        )
+        new_op.ref = unbind_reference
+    elif isinstance(op, BinaryUfuncInfo):
+        new_op.sample_inputs_func = partial(
+            sample_inputs_elementwise_njt_binary, op_kwargs=None
+        )
+        new_op.ref = unbind_reference
+    elif isinstance(op, ReductionOpInfo):
+        new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None)
+        new_op.ref = reduction_reference
+    # TODO: Translate the rest of the OpInfos
+    else:  # no translation available: install the unsupported_* stand-ins and flag the op
+        new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name)
+        new_op.ref = unsupported_reference(op.full_name)
+        new_op.supports_njt = False
+
+    return new_op
+
+
+njt_op_db = [translate_opinfo(op) for op in op_db]  # one translated entry per op_db entry, same order
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81efd19dbc6c804f066fd89a7068dce8ecf515f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py
@@ -0,0 +1,459 @@
+# mypy: ignore-errors
+
+import unittest
+from collections.abc import Callable
+from functools import partial
+from itertools import product
+
+import numpy
+
+import torch
+from torch.testing._internal.common_dtype import floating_types
+from torch.testing._internal.common_utils import TEST_SCIPY
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    ErrorInput,
+    OpInfo,
+    SampleInput,
+)
+
+
+if TEST_SCIPY:
+    import scipy.signal
+
+
+def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
+    r"""Base function used to create sample inputs for windows.
+
+    For additional required args you should use *args, as well as **kwargs for
+    additional keyword arguments.
+    """
+
+    # Remove include_conjugated_inputs from kwargs so it is not forwarded to SampleInput below
+    kwargs.pop("include_conjugated_inputs", None)
+    # Tests window sizes 0 through 5, each with sym=True and sym=False.
+    for size, sym in product(range(6), (True, False)):
+        yield SampleInput(
+            size,
+            *args,
+            sym=sym,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            **kwargs,
+        )
+
+
+def reference_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
+    r"""Reference inputs function to use for windows which have a common signature, i.e.,
+    window size and sym only.
+
+    Implement other special functions for windows that have a specific signature.
+    See exponential and gaussian windows for instance.
+    """
+    yield from sample_inputs_window(
+        op_info, device, dtype, requires_grad, *args, **kwargs
+    )
+    # Larger window sizes, yielded in addition to the small ones from sample_inputs_window.
+    cases = (8, 16, 32, 64, 128, 256)
+    # NOTE(review): device/dtype/requires_grad/*args/**kwargs are not forwarded below — confirm intended.
+    for size in cases:
+        yield SampleInput(size, sym=False)
+        yield SampleInput(size, sym=True)
+
+
+def reference_inputs_exponential_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+    # Larger window sizes, each paired with its own center/tau keyword arguments.
+    cases = (
+        (8, {"center": 4, "tau": 0.5}),
+        (16, {"center": 8, "tau": 2.5}),
+        (32, {"center": 16, "tau": 43.5}),
+        (64, {"center": 20, "tau": 3.7}),
+        (128, {"center": 62, "tau": 99}),
+        (256, {"tau": 10}),
+    )
+    # NOTE(review): device/dtype/requires_grad/**kwargs are not forwarded below — confirm intended.
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        kw["center"] = None  # the sym=True sample is built with center=None (kw is edited in place)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+    # Larger window sizes, each paired with its own std value.
+    cases = (
+        (8, {"std": 0.1}),
+        (16, {"std": 1.2}),
+        (32, {"std": 2.1}),
+        (64, {"std": 3.9}),
+        (128, {"std": 4.5}),
+        (256, {"std": 10}),
+    )
+    # NOTE(review): device/dtype/requires_grad/**kwargs are not forwarded below — confirm intended.
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_kaiser_window(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+    # Larger window sizes, each paired with its own beta value.
+    cases = (
+        (8, {"beta": 2}),
+        (16, {"beta": 12}),
+        (32, {"beta": 30}),
+        (64, {"beta": 35}),
+        (128, {"beta": 41.2}),
+        (256, {"beta": 100}),
+    )
+    # NOTE(review): device/dtype/requires_grad/**kwargs are not forwarded below — confirm intended.
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_general_cosine_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+    # Larger window sizes, each paired with its own coefficient list `a` (1 to 5 coefficients).
+    cases = (
+        (8, {"a": [0.5, 0.5]}),
+        (16, {"a": [0.46, 0.54]}),
+        (32, {"a": [0.46, 0.23, 0.31]}),
+        (64, {"a": [0.5]}),
+        (128, {"a": [0.1, 0.8, 0.05, 0.05]}),
+        (256, {"a": [0.2, 0.2, 0.2, 0.2, 0.2]}),
+    )
+    # NOTE(review): device/dtype/requires_grad/**kwargs are not forwarded below — confirm intended.
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_general_hamming_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+    # Larger window sizes, each paired with its own alpha value.
+    cases = (
+        (8, {"alpha": 0.54}),
+        (16, {"alpha": 0.5}),
+        (32, {"alpha": 0.23}),
+        (64, {"alpha": 0.8}),
+        (128, {"alpha": 0.9}),
+        (256, {"alpha": 0.05}),
+    )
+    # NOTE(review): device/dtype/requires_grad/**kwargs are not forwarded below — confirm intended.
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def error_inputs_window(op_info, device, *args, **kwargs):  # error cases shared by the window OpInfos
+    # Tests for windows that have a negative size (*args/**kwargs are forwarded to every sample)
+    yield ErrorInput(
+        SampleInput(-1, *args, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="requires non-negative window length, got M=-1",
+    )
+
+    # Tests for window tensors that are not torch.strided, for instance, torch.sparse_coo.
+    yield ErrorInput(
+        SampleInput(
+            3,
+            *args,
+            layout=torch.sparse_coo,
+            device=device,
+            dtype=torch.float32,
+            **kwargs,
+        ),
+        error_type=ValueError,
+        error_regex="is implemented for strided tensors only, got: torch.sparse_coo",
+    )
+
+    # Tests for window tensors that are not floating point dtypes, for instance, torch.long.
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.long, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.int64",
+    )
+
+    # Tests for window tensors that are bfloat16 (a float dtype other than float32/float64)
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.bfloat16, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.bfloat16",
+    )
+
+    # Tests for window tensors that are float16 (a float dtype other than float32/float64)
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.float16, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.float16",
+    )
+
+
+def error_inputs_exponential_window(op_info, device, **kwargs):
+    # Yield common error inputs (no tau is supplied for these shared cases)
+    yield from error_inputs_window(op_info, device, **kwargs)
+
+    # Tests for negative decay values.
+    yield ErrorInput(
+        SampleInput(3, tau=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Tau must be positive, got: -1 instead.",
+    )
+
+    # Tests for symmetric windows and a given center value (note: **kwargs is not forwarded here).
+    yield ErrorInput(
+        SampleInput(3, center=1, sym=True, dtype=torch.float32, device=device),
+        error_type=ValueError,
+        error_regex="Center must be None for symmetric windows",
+    )
+
+
+def error_inputs_gaussian_window(op_info, device, **kwargs):
+    # Yield common error inputs, with std=0.5 supplied explicitly for the shared cases
+    yield from error_inputs_window(op_info, device, std=0.5, **kwargs)
+
+    # Tests for negative standard deviations
+    yield ErrorInput(
+        SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Standard deviation must be positive, got: -1 instead.",
+    )
+
+
+def error_inputs_kaiser_window(op_info, device, **kwargs):
+    # Yield common error inputs, with beta=12 supplied explicitly for the shared cases
+    yield from error_inputs_window(op_info, device, beta=12, **kwargs)
+
+    # Tests for negative beta
+    yield ErrorInput(
+        SampleInput(3, beta=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="beta must be non-negative, got: -1 instead.",
+    )
+
+
+def error_inputs_general_cosine_window(op_info, device, **kwargs):
+    # Yield common error inputs, with a=[0.54, 0.46] supplied explicitly for the shared cases
+    yield from error_inputs_window(op_info, device, a=[0.54, 0.46], **kwargs)
+
+    # Tests for coefficients that are not a list/tuple (a=None)
+    yield ErrorInput(
+        SampleInput(3, a=None, dtype=torch.float32, device=device, **kwargs),
+        error_type=TypeError,
+        error_regex="Coefficients must be a list/tuple",
+    )
+    # Tests for an empty coefficient list (a=[])
+    yield ErrorInput(
+        SampleInput(3, a=[], dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Coefficients cannot be empty",
+    )
+
+
+def reference_signal_window(fn: Callable):
+    r"""Wrapper for scipy signal window references.
+
+    Discards keyword arguments for window reference functions that don't have a matching signature with
+    torch, e.g., gaussian window.
+    """
+    # _fn names device/layout/requires_grad explicitly so they are absorbed and never reach `fn`.
+    def _fn(
+        *args,
+        dtype=numpy.float64,
+        device=None,
+        layout=torch.strided,
+        requires_grad=False,
+        **kwargs,
+    ):
+        r"""The unused arguments are defined to disregard those values"""
+        return fn(*args, **kwargs).astype(dtype)  # of the named keywords only dtype is used (as a cast)
+
+    return _fn
+
+
+def make_signal_windows_opinfo(
+    name: str,
+    ref: Callable,
+    sample_inputs_func: Callable,
+    reference_inputs_func: Callable,
+    error_inputs_func: Callable,
+    *,
+    skips: tuple[DecorateInfo, ...] = (),  # extra decorators appended after the common ones below
+):
+    r"""Helper function to create OpInfo objects related to different windows."""
+    return OpInfo(
+        name=name,
+        ref=ref if TEST_SCIPY else None,  # no reference is attached when TEST_SCIPY is false
+        dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_func,
+        reference_inputs_func=reference_inputs_func,
+        error_inputs_func=error_inputs_func,
+        supports_out=False,
+        supports_autograd=False,
+        skips=(
+            # TODO: same as this?
+            # https://github.com/pytorch/pytorch/issues/81774
+            # also see: arange, new_full
+            # fails to match any schemas despite working in the interpreter
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestOperatorSignatures",
+                "test_get_torch_func_signature_exhaustive",
+            ),
+            # fails to match any schemas despite working in the interpreter
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # skip these tests since we have non tensor input
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestVmapOperatorsOpInfo",
+                "test_vmap_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestVmapOperatorsOpInfo",
+                "test_op_has_batch_rule",
+            ),
+            DecorateInfo(
+                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+            *skips,  # caller-supplied DecorateInfo entries
+        ),
+    )
+
+
+op_db: list[OpInfo] = [  # one OpInfo per "signal.windows.*" name; scipy is only touched when TEST_SCIPY
+    make_signal_windows_opinfo(
+        name="signal.windows.hamming",
+        ref=reference_signal_window(scipy.signal.windows.hamming)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.hann",
+        ref=reference_signal_window(scipy.signal.windows.hann) if TEST_SCIPY else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.bartlett",
+        ref=reference_signal_window(scipy.signal.windows.bartlett)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.blackman",
+        ref=reference_signal_window(scipy.signal.windows.blackman)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.cosine",
+        ref=reference_signal_window(scipy.signal.windows.cosine)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.exponential",
+        ref=reference_signal_window(scipy.signal.windows.exponential)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, tau=2.78),  # window-specific kwarg bound here
+        reference_inputs_func=partial(reference_inputs_exponential_window, tau=2.78),
+        error_inputs_func=error_inputs_exponential_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.gaussian",
+        ref=reference_signal_window(scipy.signal.windows.gaussian)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, std=1.92),
+        reference_inputs_func=partial(reference_inputs_gaussian_window, std=1.92),
+        error_inputs_func=error_inputs_gaussian_window,
+        skips=(  # NOTE(review): make_signal_windows_opinfo already adds this same MPS skip — redundant?
+            DecorateInfo(
+                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.kaiser",
+        ref=reference_signal_window(scipy.signal.windows.kaiser)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, beta=12.0),
+        reference_inputs_func=partial(reference_inputs_kaiser_window, beta=12.0),
+        error_inputs_func=error_inputs_kaiser_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.general_cosine",
+        ref=reference_signal_window(scipy.signal.windows.general_cosine)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, a=[0.54, 0.46]),
+        reference_inputs_func=partial(
+            reference_inputs_general_cosine_window, a=[0.54, 0.46]
+        ),
+        error_inputs_func=error_inputs_general_cosine_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.general_hamming",
+        ref=reference_signal_window(scipy.signal.windows.general_hamming)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, alpha=0.54),
+        reference_inputs_func=partial(
+            reference_inputs_general_hamming_window, alpha=0.54
+        ),
+        error_inputs_func=error_inputs_window,  # shared error cases only; called without alpha
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.nuttall",
+        ref=reference_signal_window(scipy.signal.windows.nuttall)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..200a3ad9ed902962edcc2da0153117e83d64131a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py
@@ -0,0 +1,928 @@
+# mypy: ignore-errors
+
+import os
+
+import torch
+from torch.testing import make_tensor  # noqa: F401
+from torch.testing._internal.opinfo.core import (  # noqa: F401
+    BinaryUfuncInfo,
+    ErrorInput,
+    generate_elementwise_binary_tensors,
+    ReductionOpInfo,
+    sample_inputs_reduction,
+    SampleInput,
+)
+
+
+def _check_validate(op_info, sample):  # call op_info on `sample`; AssertionError on an unexpected outcome
+    def _check_fail(sample):  # ErrorInput case: the call must raise sample.error_type
+        try:
+            op_info(
+                sample.sample_input.input,
+                *sample.sample_input.args,
+                **sample.sample_input.kwargs,
+            )
+        except sample.error_type:  # only the type is checked; error_regex appears in messages only
+            pass
+        except Exception as msg:
+            raise AssertionError(  # noqa: B904
+                f"{op_info.name} on {sample.sample_input=} expected exception "
+                f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}"
+            )
+        else:
+            raise AssertionError(
+                f"{op_info.name} on {sample.sample_input=} expected exception "
+                f"{sample.error_type}: {sample.error_regex}, got none."
+            )
+
+    def _check_success(sample):  # non-error case: the call must not raise
+        try:
+            op_info(sample.input, *sample.args, **sample.kwargs)
+        except Exception as msg:
+            raise AssertionError(  # noqa: B904
+                f"{op_info.name} on {sample=} expected to succeed "
+                f", got {type(msg).__name__}: {msg}"
+            )
+    # Dispatch on the sample kind: an ErrorInput must fail, anything else must succeed.
+    if isinstance(sample, ErrorInput):
+        _check_fail(sample)
+    else:
+        _check_success(sample)
+
+
+def _sample_inputs_sparse(
+    sample_inputs,
+    maybe_failing_sample_inputs,
+    validate_sample_input,
+    op_info,
+    *args,
+    **kwargs,
+):
+    check_validate = (  # opt-in via environment variable; validation is off unless it equals "1"
+        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
+    )
+    for sample in sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, SampleInput):
+            yield sample
+        # Error inputs are handled in error_inputs_sparse
+    # Same filtering for the maybe-failing generator: only results that are SampleInput are yielded.
+    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, SampleInput):
+            yield sample
+
+
+def _error_inputs_sparse(
+    maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs
+):
+    check_validate = (  # opt-in via environment variable; validation is off unless it equals "1"
+        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
+    )
+    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, ErrorInput):  # keep only samples the validator turned into ErrorInput
+            yield sample
+        # Sample inputs are handled in sample_inputs_sparse
+
+
+def _apply_requires_grad_to_samples(sample_inputs):
+    """Decorator to _maybe_failing_sample_inputs_... generator functions
+    that clones and sets requires_grad argument to tensors in sample
+    input arguments. This is needed when the generated samples share
+    tensor instances.
+    """
+
+    def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs):
+        def apply_requires_grad(x):
+            if (  # returned unchanged: non-tensors, tensors already requiring grad, non-float/complex
+                not isinstance(x, torch.Tensor)
+                or x.requires_grad
+                or not requires_grad
+                or not (x.is_floating_point() or x.is_complex())
+            ):
+                return x
+            return x.detach().clone().requires_grad_(requires_grad)  # a fresh clone per sample
+        # Samples are only transformed when requires_grad is true; otherwise they pass through as-is.
+        if requires_grad:
+            for sample_input in sample_inputs(
+                op_info, device, dtype, requires_grad, layout, **kwargs
+            ):
+                yield sample_input.transform(apply_requires_grad)
+        else:
+            yield from sample_inputs(
+                op_info, device, dtype, requires_grad, layout, **kwargs
+            )
+
+    return wrapper
+
+
+def sample_inputs_sparse_reduction(
+    op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs
+):
+    """Sample inputs for reduction operations on sparse tensors."""
+    layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0]  # drops the prefix up to "." and a "_coo" suffix
+    op_supports_layout = getattr(op_info, "supports_" + layout_name)
+    if not op_supports_layout:  # nothing is yielded for layouts the op does not declare support for
+        return
+
+    for sample_input in sample_inputs_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        if sample_input.input.ndim == 0:
+            # scalar sparse tensors are not supported
+            continue
+
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if sample_input.input.ndim < 2:
+                # conversion to sparse compressed tensors requires at
+                # least 2 dimensional tensors
+                continue
+            if sample_input.input.ndim > 2 and (sample_input.input == 0).any():
+                # Skip batched sparse compressed samples that contain
+                # explicit zeros because to_sparse(layout=..) will
+                # fail, see gh-98495.
+                # TODO: remove this if-block after gh-98495 is fixed.
+                continue
+
+        if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None:
+            blocksize = (1, 1)  # default for BSR/BSC when the caller gave none; kept for later iterations
+
+        yield SampleInput(
+            sample_input.input.detach()
+            .to_sparse(layout=layout, blocksize=blocksize)
+            .requires_grad_(requires_grad),
+            args=sample_input.args,
+            kwargs=sample_input.kwargs,
+        )
+
+        if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex):
+            # uncoalesced samples: the indices and the values are each repeated twice
+            inp = sample_input.input.detach().to_sparse(layout=layout)
+            inp = torch.sparse_coo_tensor(
+                inp.indices().repeat(1, 2),
+                inp.values().repeat(2),
+                inp.shape,
+                dtype=inp.dtype,
+                device=inp.device,
+            )
+            assert not inp.is_coalesced()
+            yield SampleInput(
+                inp.requires_grad_(requires_grad),
+                args=sample_input.args,
+                kwargs=sample_input.kwargs,
+            )
+
+        if sample_input.input.ndim > 2:
+            # hybrid samples: converted with dense_dim = ndim - 2
+            yield SampleInput(
+                sample_input.input.detach()
+                .to_sparse(
+                    layout=layout,
+                    blocksize=blocksize,
+                    dense_dim=sample_input.input.ndim - 2,
+                )
+                .requires_grad_(requires_grad),
+                args=sample_input.args,
+                kwargs=sample_input.kwargs,
+            )
+
+
+def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False):
+    """Return the specified sample when it is valid and supported by the
+    operation. Otherwise, return the sample as ErrorInput instance.
+
+    When check_validate is True, the result is validated against
+    calling the op on the sample.
+    """
+    UNSPECIFIED = object()  # sentinel: tells "mask kwarg absent" apart from an explicit mask=None
+    if op_info.name == "sum":
+        sample = _validate_sample_input_sparse_reduction_sum(sample)  # sum has its own validator
+    # masked.sum: the first matching condition below decides which ErrorInput wraps the sample.
+    if op_info.name == "masked.sum":
+        mask = sample.kwargs.get("mask", UNSPECIFIED)
+        if (
+            mask not in {None, UNSPECIFIED}
+            and mask.ndim > 2
+            and mask.layout is torch.strided
+            and (mask == 0).any()
+        ):
+            # TODO: remove this if-block after gh-98495 is fixed.
+            sample = ErrorInput(
+                sample,
+                error_regex="Expect the same number of specified elements per batch.",
+            )
+        elif not sample.kwargs.get("keepdim"):
+            sample = ErrorInput(
+                sample,
+                error_type=(AssertionError, RuntimeError),
+                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
+            )
+        elif mask is UNSPECIFIED:
+            sample = ErrorInput(
+                sample,
+                error_type=ValueError,
+                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
+            )
+        elif sample.input.ndim > 2:
+            sample = ErrorInput(
+                sample,
+                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
+            )
+    # Other masked reductions: same first-match scheme, with more conditions.
+    if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}:
+        t_inp = sample.input
+        mask = sample.kwargs.get("mask")  # here an absent mask and mask=None are treated alike
+        if (
+            mask is not None
+            and mask.ndim > 2
+            and mask.layout is torch.strided
+            and (mask == 0).any()
+        ):
+            # TODO: remove this if-block after gh-98495 is fixed.
+            sample = ErrorInput(
+                sample,
+                error_regex="Expect the same number of specified elements per batch.",
+            )
+        elif mask is None:
+            sample = ErrorInput(
+                sample,
+                error_type=ValueError,
+                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
+            )
+        elif (
+            mask.layout is sample.input.layout
+            and mask.ndim > 2
+            and op_info.name == "masked.mean"
+        ):
+            sample = ErrorInput(
+                sample,
+                error_type=TypeError,
+                error_regex=(
+                    "where[(][)] received an invalid combination of arguments"
+                    " - got [(]Tensor, Tensor, NoneType[)]"
+                ),
+            )
+        elif not sample.kwargs.get("keepdim"):
+            sample = ErrorInput(
+                sample,
+                error_type=(AssertionError, RuntimeError),
+                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
+            )
+        elif (
+            sample.input.ndim > 2
+            and (sample.kwargs.get("dim") not in {0, 1})
+            and mask.ndim > 2
+            and mask.layout is not torch.strided
+        ):
+            if sample.kwargs.get("dim") == (0, -1):
+                sample = ErrorInput(
+                    sample,
+                    error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities",
+                )
+            elif op_info.name == "masked.prod":
+                sample = ErrorInput(
+                    sample,
+                    error_regex="input_dim == 2 INTERNAL ASSERT FAILED at",
+                )
+            else:
+                sample = ErrorInput(
+                    sample,
+                    error_type=AssertionError,
+                    error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.",
+                )
+        elif sample.input.ndim > 2:
+            sample = ErrorInput(
+                sample,
+                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
+            )
+        elif (
+            mask.layout is t_inp.layout
+            and mask._nnz() != t_inp._nnz()
+            and t_inp.dense_dim() > 0
+        ):
+            sample = ErrorInput(
+                sample,
+                error_regex="Index tensor must have the same number of dimensions as src tensor",
+            )
+    # Optionally cross-check the classification above by actually calling the op.
+    if check_validate:
+        _check_validate(op_info, sample)
+
+    return sample
+
+
+def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
+    # NOTE: When fixing a failing sample case, remove the
+    #       corresponding if-block
+    t_inp, t_kwargs = sample.input, sample.kwargs
+    dim = t_kwargs.get("dim")
+    keepdim = t_kwargs.get("keepdim")
+    layout = t_inp.layout
+    if isinstance(dim, (int, list, tuple)):
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
+                return ErrorInput(
+                    sample,
+                    error_regex=(
+                        "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout"
+                    ),
+                )
+            if layout in {torch.sparse_csr, torch.sparse_csc} and not keepdim:
+                return ErrorInput(
+                    sample,
+                    error_regex=(
+                        "reduction operations on CSR tensors with keepdim=False is unsupported"
+                    ),
+                )
+            if t_inp.dim() != 2:
+                return ErrorInput(
+                    sample,
+                    error_regex=("input_dim == 2 INTERNAL ASSERT"),
+                )
+            if layout == torch.sparse_csr:
+                if t_inp.dtype == torch.bool:
+                    return ErrorInput(
+                        sample,
+                        error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"),
+                    )
+                if t_inp.dtype == torch.complex32:
+                    return ErrorInput(
+                        sample,
+                        error_regex=(
+                            "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'"
+                        ),
+                    )
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_reduction_sum(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Generator of samples that are known to fail or that were failing in past.
+
+    Each yielded ``SampleInput`` wraps a small dense tensor literal that is
+    converted with ``to_sparse`` to ``layout`` and marked with
+    ``requires_grad_(requires_grad)``; the reduction arguments (``dim``,
+    ``keepdim``) are carried in the sample's kwargs.
+
+    Samples are produced only for the CSR/CSC layouts (first group) and
+    the BSR/BSC layouts (second group); for any other layout nothing is
+    yielded. ``op_info``, ``device`` and ``kwargs`` are not read here.
+    """
+    # NOTE: When fixing a failing case, remove the Exception comment
+    #       but keep the `yield sample` statement.
+    # NOTE(review): `device` is not forwarded to the torch.tensor calls
+    #       below, so the inputs are created on the default device; the
+    #       sibling generators for mul and the like-functions do pass
+    #       device=device. Confirm whether this is intentional.
+    if layout in [
+        torch.sparse_csr,
+        torch.sparse_csc,
+    ]:
+        # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend.
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0, keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,), keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+
+        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+
+    # Block layouts additionally need an explicit blocksize for to_sparse.
+    if layout in [
+        torch.sparse_bsr,
+        torch.sparse_bsc,
+    ]:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(2, 2))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0, keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,), keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1), dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+
+        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+
+
+def sample_inputs_sparse_reduction_sum(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for sum on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        sample_inputs_sparse_reduction,
+        _maybe_failing_sample_inputs_sparse_reduction_sum,
+        _validate_sample_input_sparse_reduction,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs):
+    """Error inputs for sum on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_reduction_sum,
+        _validate_sample_input_sparse_reduction,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def sample_inputs_sparse_elementwise_binary_operation(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for elementwise binary operations on sparse tensors.
+
+    The samples include regular, zero-sized, batched, and hybrid
+    sparse tensors as well as rhs scalars. All tensors are full tensors.
+    """
+
+    def _to_sparse(tensor, **kwargs):
+        return tensor.detach().to_sparse(**kwargs).requires_grad_(requires_grad)
+
+    for sample_input in generate_elementwise_binary_tensors(
+        op_info,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=True,
+        **kwargs,
+    ):
+        lhs, rhs = sample_input.input, sample_input.args[0]
+        min_dense_dim = 0
+        max_dense_dim = lhs.ndim - 1
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if lhs.ndim < 2:
+                # sparse compressed tensors sparse_dim must be 2
+                continue
+            max_dense_dim = lhs.ndim - 2
+
+        for dense_dim in range(min_dense_dim, max_dense_dim + 1):
+            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
+                blocksizes = [(1, 1)]
+                if lhs.numel() > 0:
+                    blocksizes.append(
+                        (
+                            lhs.shape[lhs.ndim - 2 - dense_dim],
+                            lhs.shape[lhs.ndim - 1 - dense_dim],
+                        )
+                    )
+            else:
+                blocksizes = [None]
+            for blocksize in blocksizes:
+                to_sparse_kwargs = dict(
+                    layout=layout, dense_dim=dense_dim, blocksize=blocksize
+                )
+                lhs_sparse = _to_sparse(lhs, **to_sparse_kwargs)
+                rhs_sparse = _to_sparse(rhs, **to_sparse_kwargs)
+                # op(sparse, sparse)
+                yield SampleInput(
+                    lhs_sparse,
+                    args=(rhs_sparse, *sample_input.args[1:]),
+                    kwargs=sample_input.kwargs,
+                )
+                # op(sparse, scalar)
+                yield SampleInput(
+                    lhs_sparse,
+                    args=(
+                        make_tensor(
+                            (), dtype=dtype, device=device, requires_grad=requires_grad
+                        ),
+                        *sample_input.args[1:],
+                    ),
+                    kwargs=sample_input.kwargs,
+                )
+
+
+def _validate_sample_input_elementwise_binary_sparse_mul(sample):
+    """Classify a ``mul`` sample with a sparse lhs as supported or failing.
+
+    Inspects the layout, dtype and dimensionality of ``sample.input`` and
+    of the first positional argument. The first matching branch of the
+    chain below wins and returns an ``ErrorInput`` wrapping ``sample``
+    with a regex for the expected error message; when no branch matches,
+    ``sample`` is returned unchanged. Branch order is significant because
+    several conditions overlap.
+    """
+    # NOTE: When fixing a failing sample case, remove the
+    #       corresponding if-block
+    t_inp, t_args = sample.input, sample.args
+    # Dimensions that are neither dense nor sparse are counted as batch dimensions.
+    batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
+    layout = t_inp.layout
+    dtype = t_inp.dtype
+    # In the conditions below `t_args[0].ndim > 0` selects a non-0-dim rhs.
+    if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
+                " tensors with sparse_dim[(][)]!=2 is not supported"
+            ),
+        )
+    elif layout is torch.sparse_csc and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample, error_regex="Expected result Tensor to be of format CSR"
+        )
+    elif layout is torch.sparse_bsr and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr",
+        )
+    elif layout is torch.sparse_bsc and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc",
+        )
+    elif (
+        # bool hybrid COO input on CPU
+        layout is torch.sparse_coo
+        and dtype is torch.bool
+        and t_args[0].ndim > 0
+        and t_inp.is_cpu
+        and t_inp.numel() > 0
+        and t_inp.dense_dim() > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'"
+        )
+    elif (
+        # bool COO/CSR input on CPU with at least one specified element
+        layout in {torch.sparse_coo, torch.sparse_csr}
+        and dtype is torch.bool
+        and t_inp._nnz() > 0
+        and t_args[0].ndim > 0
+        and t_inp.is_cpu
+        and t_inp.numel() > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'"
+        )
+    elif (
+        # CSR lhs with a strided rhs of lower (but non-zero) dimensionality
+        layout is torch.sparse_csr
+        and t_args[0].layout is torch.strided
+        and 0 < t_args[0].ndim
+        and t_args[0].ndim < t_inp.ndim
+    ):
+        return ErrorInput(
+            sample, error_regex="sparse_mask_sparse_csr expects self to be 2D"
+        )
+    elif layout is torch.sparse_csr and (
+        # remaining strided non-0-dim rhs, or a CSR rhs of a different shape
+        (t_args[0].layout is torch.strided and 0 < t_args[0].ndim)
+        or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape)
+    ):
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "expects sparse inputs with equal dimensionality, number of sparse dimensions,"
+                " and shape of sparse dimensions"
+            ),
+        )
+    elif (
+        # float16 hybrid CSR input on CPU with at least one specified element
+        layout is torch.sparse_csr
+        and t_inp.dense_dim() > 0
+        and t_inp._nnz() > 0
+        and t_inp.is_cpu
+        and dtype is torch.float16
+        and t_args[0].ndim > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'"
+        )
+    return sample
+
+
+@_apply_requires_grad_to_samples
+def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Generator of samples that are known to fail or that were failing in past.
+
+    Builds three small sparse tensors in the requested ``layout`` on
+    ``device`` -- ``regular`` (2-D), ``batch`` (3-D, ``dense_dim=0``) and
+    ``hybrid`` (3-D, ``dense_dim=1``) -- and yields ``SampleInput`` objects
+    that multiply such a tensor with itself. Which samples are yielded
+    depends on ``layout`` and ``dtype``; the comment above each ``yield``
+    records the error the sample produces (or used to produce).
+    ``op_info`` and ``kwargs`` are not read here.
+    """
+    # NOTE: When fixing a failing case, remove the Exception comment
+    #       but keep the `yield sample` statement.
+
+    # Block layouts get a (1, 1) blocksize; all other layouts pass None.
+    blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+    regular = torch.tensor([[1, 2], [3, 4]], device=device, dtype=dtype).to_sparse(
+        layout=layout, dense_dim=0, blocksize=blocksize
+    )
+    batch = torch.tensor(
+        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], device=device, dtype=dtype
+    ).to_sparse(layout=layout, dense_dim=0, blocksize=blocksize)
+    hybrid = torch.tensor(
+        [[[1], [2]], [[3], [4]]], device=device, dtype=dtype
+    ).to_sparse(layout=layout, dense_dim=1, blocksize=blocksize)
+
+    if layout is torch.sparse_csr:
+        # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor
+        yield SampleInput(batch, args=(batch,))
+        # RuntimeError: Only tensors with two sparse dimensions can be
+        # converted to the SparseCsr layout, got self with 3 sparse
+        # dimensions.
+        yield SampleInput(
+            torch.zeros_like(hybrid).requires_grad_(requires_grad),
+            args=(torch.zeros_like(hybrid).requires_grad_(requires_grad),),
+        )
+        if dtype is torch.complex32:
+            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
+            yield SampleInput(regular, args=(regular,))
+        if dtype is torch.bool and regular.is_cpu:
+            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
+            yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_csc:
+        # RuntimeError: Expected result Tensor to be of format CSR
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_bsr:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_bsc:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsc
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_coo:
+        if dtype is torch.complex32:
+            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
+            yield SampleInput(regular, args=(regular,))
+        if dtype is torch.bool and regular.is_cpu:
+            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
+            yield SampleInput(regular, args=(regular,))
+        if dtype in {torch.bool, torch.float16} and regular.is_cpu:
+            # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)'
+            yield SampleInput(hybrid, args=(hybrid,))
+
+
+def _validate_sample_input_sparse_elementwise_binary_operation(
+    op_info, sample, check_validate=False
+):
+    if op_info.name == "mul":
+        sample = _validate_sample_input_elementwise_binary_sparse_mul(sample)
+
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs):
+    """Sample inputs for mul operation on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        sample_inputs_sparse_elementwise_binary_operation,
+        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
+        _validate_sample_input_sparse_elementwise_binary_operation,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
+    """Error inputs for mul operation on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
+        _validate_sample_input_sparse_elementwise_binary_operation,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def _sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    from torch.testing._internal.common_utils import TestCase
+
+    for tensor in TestCase().generate_simple_inputs(
+        layout,
+        device=device,
+        dtype=dtype,
+        enable_batch=True,
+        enable_hybrid=True,
+        enable_zero_sized=True,
+        enable_non_contiguous_indices=False,
+        enable_non_contiguous_values=False,
+    ):
+        yield SampleInput(tensor, args=(), kwargs={})
+        yield SampleInput(
+            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
+        )
+
+        if dtype is not torch.float64:
+            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
+
+        if torch.cuda.is_available():
+            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
+            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
+
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+        yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
+
+        if layout is not torch.sparse_coo:
+            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
+
+
+def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
+    if (
+        sample.input.layout
+        in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }
+        and op_info.name != "zeros_like"
+    ):
+        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
+            return ErrorInput(
+                sample,
+                error_regex=(
+                    "empty_like with different sparse layout is not supported"
+                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
+                ),
+            )
+    if sample.input.layout is torch.sparse_coo:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
+            ),
+        )
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    if torch.cuda.is_available() and layout is not torch.sparse_coo:
+        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+
+        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(device=other_device),
+        )
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(layout=other_layout),
+        )
+
+
+def sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for like-functions on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        _sample_inputs_sparse_like_fns,
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
+    """Error inputs for like-functions on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
+    if op_info.name == "to_sparse":
+        if (
+            sample.input.layout
+            in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+            and len(sample.args) == 1
+            and isinstance(sample.args[0], int)
+            and sample.args[0] != 2
+        ):
+            sample = ErrorInput(
+                sample,
+                error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse",
+            )
+
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def validate_sample_input_sparse(op_info, sample, check_validate=False):
+    """Return the specified sample when it is valid and supported by the
+    operation. Otherwise, return the sample as ErrorInput instance.
+
+    When check_validate is True, the result is validated against
+    calling the op on the sample.
+    """
+    if isinstance(op_info, ReductionOpInfo):
+        return _validate_sample_input_sparse_reduction(
+            op_info, sample, check_validate=check_validate
+        )
+    elif isinstance(op_info, BinaryUfuncInfo):
+        return _validate_sample_input_sparse_elementwise_binary_operation(
+            op_info, sample, check_validate=check_validate
+        )
+    else:
+        return _validate_sample_input_sparse_default(
+            op_info, sample, check_validate=check_validate
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py
new file mode 100644
index 0000000000000000000000000000000000000000..47cbcb1fb4268aa8261e38cd6b197a15c39a4428
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py
@@ -0,0 +1,805 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+from itertools import product
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import (
+    precisionOverride,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import all_types_and, floating_types
+from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    DecorateInfo,
+    L,
+    NumericsFilter,
+    OpInfo,
+    S,
+    SampleInput,
+    UnaryUfuncInfo,
+)
+from torch.testing._internal.opinfo.refs import (
+    ElementwiseBinaryPythonRefInfo,
+    ElementwiseUnaryPythonRefInfo,
+)
+from torch.testing._internal.opinfo.utils import (
+    np_unary_ufunc_integer_promotion_wrapper,
+)
+
+
+if TEST_SCIPY:
+    import scipy.special
+
+
+# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
+#       supports `exclude` argument.
+#       For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
+def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
+    exclude_zero = requires_grad and op_info.op is torch.special.i0e
+    make_arg = partial(
+        make_tensor,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    yield SampleInput(make_arg((S,)))
+    yield SampleInput(make_arg(()))
+
+    if requires_grad and not exclude_zero:
+        # Special Case for gradient
+        # Sample with `0` in the input
+        t = make_arg((S,))
+        t[0] = 0
+
+        yield SampleInput(t)
+
+
+def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        # TODO: eliminate low after gh-106692 is fixed:
+        low=(1 if dtype in {torch.int32, torch.int64} else None),
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    tensor_shapes = ((S, S), ())
+    ns = (1, 2, 3, 4, 5)
+
+    for shape, n in product(tensor_shapes, ns):
+        yield SampleInput(make_arg(shape), args=(n,))
+
+
+def reference_polygamma(x, n):
+    # WEIRD `scipy.special.polygamma` behavior
+    # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
+    # dtype('float64')
+    # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
+    # dtype('float32')
+    #
+    # Thus we cast output to the default torch dtype or preserve double
+    result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
+    if x.dtype == np.double:
+        result_dtype = np.double
+    return scipy.special.polygamma(n, x).astype(result_dtype)
+
+
+def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
+    low, _ = op_info.domain
+
+    if requires_grad:
+        low = 0 + op_info._domain_eps
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg((L,)))
+    yield SampleInput(make_arg(()))
+
+
+def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
+    for shape in ((L,), (1, 0, 3), ()):
+        yield SampleInput(
+            make_tensor(
+                shape,
+                device=device,
+                dtype=dtype,
+                low=-5,
+                requires_grad=requires_grad,
+            ),
+        )
+
+
+op_db: list[OpInfo] = [
+    UnaryUfuncInfo(
+        "special.i0e",
+        aten_name="special_i0e",
+        ref=scipy.special.i0e if TEST_SCIPY else None,
+        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_i0_i1,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.i1",
+        aten_name="special_i1",
+        ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
+        if TEST_SCIPY
+        else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        backward_dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_i0_i1,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1e-4, rtol=0),
+                        torch.bool: tol(atol=1e-4, rtol=0),
+                    }
+                )
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Incorrect result!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=(torch.int8,),
+            ),
+        ),
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.i1e",
+        aten_name="special_i1e",
+        ref=scipy.special.i1e if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        backward_dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_i0_i1,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.ndtr",
+        aten_name="special_ndtr",
+        decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
+        ref=scipy.special.ndtr if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Dispatch stub: unsupported device typemeta
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="meta",
+            ),
+        ),
+    ),
+    # A separate OpInfo entry for special.polygamma is needed to reorder the arguments
+    # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
+    UnaryUfuncInfo(
+        "special.polygamma",
+        op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
+        variant_test_name="special_polygamma_n_0",
+        ref=reference_polygamma if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_polygamma,
+        skips=(
+            # lambda impl
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+        ),
+        sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
+        # polygamma functions have multiple singularities at x having non-positive integer value
+        reference_numerics_filter=NumericsFilter(
+            condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
+        ),
+    ),
+    BinaryUfuncInfo(
+        "special.xlog1py",
+        aten_name="special_xlog1py",
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        promotes_int_to_float=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_one_python_scalar=True,
+        # We don't test -1 as the gradient will be NaN and it'll break
+        rhs_make_tensor_kwargs=dict(low=-0.99),
+    ),
+    BinaryUfuncInfo(
+        "special.zeta",
+        aten_name="special_zeta",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        supports_autograd=False,
+        supports_one_python_scalar=True,
+        skips=(
+            # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+    # TODO: FIXME
+    # OpInfo entry to verify the gradient formula of `other`/`q`
+    # BinaryUfuncInfo('special.zeta',
+    #                 op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
+    #                 aten_name='special_zeta',
+    #                 variant_test_name='grad',
+    #                 dtypes=all_types_and(torch.bool),
+    #                 promotes_int_to_float=True,
+    #                 supports_autograd=True,
+    #                 supports_rhs_python_scalar=False,
+    #                 decorators=[
+    #                     # Derivative wrt first tensor not implemented
+    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
+    #                                  "test_floating_inputs_are_differentiable")
+    #                 ],
+    #                 skips=(
+    #                     # Lambda doesn't work in JIT test
+    #                     # AssertionError: JIT Test does not execute any logic
+    #                     DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
+    #                 )),
+    UnaryUfuncInfo(
+        "special.entr",
+        ref=scipy.special.entr if TEST_SCIPY else None,
+        aten_name="special_entr",
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=[torch.bfloat16, torch.float16],
+            ),
+        ),
+        supports_inplace_autograd=False,
+        sample_inputs_func=sample_inputs_entr,
+    ),
+    UnaryUfuncInfo(
+        "special.ndtri",
+        ref=scipy.special.ndtri if TEST_SCIPY else None,
+        domain=(0, 1),
+        aten_name="special_ndtri",
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.log_ndtr",
+        aten_name="special_log_ndtr",
+        ref=scipy.special.log_ndtr if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.erfcx",
+        ref=scipy.special.erfcx if TEST_SCIPY else None,
+        aten_name="special_erfcx",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=0, rtol=4e-6),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_erfcx,
+    ),
+    UnaryUfuncInfo(
+        "special.airy_ai",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else None,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+            ),
+        ),
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_j0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.j0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_j1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.j1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_y0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.y0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_y1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.y1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_t",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_u",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_v",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_w",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.hermite_polynomial_h",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: inf
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.hermite_polynomial_he",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: inf
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.laguerre_polynomial_l",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.legendre_polynomial_p",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_i0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.i0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_i1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.i1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_k0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_k1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.scaled_modified_bessel_k0",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k0e if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.scaled_modified_bessel_k1",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k1e if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_t",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_u",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_v",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_w",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.spherical_bessel_j0",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else None,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Scipy doesn't support bool inputs to spherical_bessel_j0"
+                ),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_normal",
+                dtypes=(torch.bool,),
+            ),
+        ),
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    # Python-reference entries for torch.special: each "_refs.special.*" entry
+    # names its op_db counterpart via torch_opinfo_name and is passed op_db.
+    # --- Elementwise Unary Special OpInfos ---
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.bessel_j0",
+        torch_opinfo_name="special.bessel_j0",
+        op_db=op_db,
+        decorators=(  # same values as the "special.bessel_j0" op_db entry
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.bessel_j1",
+        torch_opinfo_name="special.bessel_j1",
+        op_db=op_db,
+        decorators=(  # same values as the "special.bessel_j1" op_db entry
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.entr",
+        torch_opinfo_name="special.entr",
+        op_db=op_db,
+        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
+        skips=(  # same skip as the "special.entr" op_db entry
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=[torch.bfloat16, torch.float16],
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.erfcx",
+        torch_opinfo_name="special.erfcx",
+        op_db=op_db,
+        decorators=(  # same tolerance as the "special.erfcx" op_db entry
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=0, rtol=4e-6),
+                }
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i0e",
+        torch_opinfo_name="special.i0e",
+        op_db=op_db,
+        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i1",
+        torch_opinfo_name="special.i1",
+        op_db=op_db,
+        decorators=(
+            DecorateInfo(  # no test class/name given; presumably applies broadly (confirm)
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1e-4, rtol=0),
+                        torch.bool: tol(atol=1e-4, rtol=0),
+                    }
+                )
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Incorrect result!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=(torch.int8,),
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i1e",
+        torch_opinfo_name="special.i1e",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.log_ndtr",
+        torch_opinfo_name="special.log_ndtr",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.ndtr",
+        torch_opinfo_name="special.ndtr",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.ndtri",
+        torch_opinfo_name="special.ndtri",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.spherical_bessel_j0",
+        torch_opinfo_name="special.spherical_bessel_j0",
+        op_db=op_db,
+        decorators=(  # same tolerances as the "special.spherical_bessel_j0" op_db entry
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Scipy doesn't support bool inputs to spherical_bessel_j0"
+                ),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_normal",
+                dtypes=(torch.bool,),
+            ),
+        ),
+    ),
+    #
+    # --- Elementwise Binary Special OpInfos ---
+    #
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.special.zeta",
+        torch_opinfo_name="special.zeta",
+        supports_one_python_scalar=True,
+        op_db=op_db,
+        skips=(
+            # Expected failure: reference inputs give nans and infs on cuda but nan, inf, 0., -inf on cpu
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/static_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/static_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a031b0d8f6e685517b7ac51c236e23835501cd9
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/static_module.py
@@ -0,0 +1,34 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["module: unknown"]
+
+import torch
+
+
+class StaticModule:
+    """Thin wrapper around what ``torch._C._jit_to_static_module`` returns."""
+
+    def __init__(self, scripted):
+        # An object exposing ``_c`` (an nn.Module, per the original note) is
+        # converted from that attribute; anything else from its ``graph``.
+        source = scripted._c if hasattr(scripted, "_c") else scripted.graph
+        self.static_module = torch._C._jit_to_static_module(source)
+
+    def __call__(self, *args, **kwargs):
+        """Invoke the wrapped static module with the given arguments."""
+        wrapped = self.static_module
+        return wrapped(*args, **kwargs)
+
+    def benchmark(self, args, kwargs, warmup_runs, main_runs):
+        """Call the wrapped module's ``benchmark``; nothing is returned."""
+        wrapped = self.static_module
+        wrapped.benchmark(args, kwargs, warmup_runs, main_runs)
+
+    def runAsync(self, args, kwargs):
+        """Return the result of the wrapped module's ``runAsync``."""
+        wrapped = self.static_module
+        return wrapped.runAsync(args, kwargs)
+
+    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
+        """Return the result of the wrapped module's per-op benchmark."""
+        run_bench = self.static_module.benchmark_individual_ops
+        return run_bench(args, kwargs, warmup_runs, main_runs)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py
new file mode 100644
index 0000000000000000000000000000000000000000..228f98139fea5adc1078cdcf7ede2a4adc4d6ede
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py
@@ -0,0 +1,92 @@
+# mypy: ignore-errors
+from typing import Any, Optional
+
+import torch
+import torch.utils._pytree as pytree
+from torch._subclasses.fake_tensor import is_fake
+from torch.testing._internal.two_tensor import TwoTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+class WrapperSubclass(torch.Tensor):
+    """Tensor subclass wrapping one inner tensor, kept as ``self.a``.
+
+    Dispatched ops run on the unwrapped inner tensors, and every tensor in
+    the result is wrapped again.
+    """
+
+    @staticmethod
+    def __new__(cls, a, outer_size=None, outer_stride=None):
+        """Create the wrapper; size and stride default to those of ``a``."""
+        size = a.size() if outer_size is None else outer_size
+        stride = a.stride() if outer_stride is None else outer_stride
+
+        # The remaining metadata is always taken from the inner tensor.
+        metadata = {
+            "strides": stride,
+            "storage_offset": a.storage_offset(),
+            "device": a.device,
+            "layout": a.layout,
+            "requires_grad": a.requires_grad,
+            "dtype": a.dtype,
+        }
+        return torch.Tensor._make_wrapper_subclass(cls, size, **metadata)
+
+    def __init__(self, a, outer_size=None, outer_stride=None):
+        # Only the inner tensor is stored; the size/stride arguments are
+        # not used here.
+        self.a = a
+
+    def __repr__(self):
+        return f"WrapperSubclass({self.a!r})"
+
+    def __tensor_flatten__(self):
+        """Return the inner-tensor attribute names and ``None`` metadata."""
+        inner_names = ["a"]
+        return inner_names, None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        """Rebuild a wrapper around ``inner_tensors["a"]``."""
+        assert meta is None
+        inner = inner_tensors["a"]
+        if is_fake(inner):
+            # A fake inner tensor must come with explicit outer metadata.
+            assert outer_size is not None
+            assert outer_stride is not None
+        return WrapperSubclass(inner, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        """Run ``func`` on unwrapped arguments and re-wrap tensor outputs."""
+        if kwargs is None:
+            kwargs = {}
+
+        def unwrap(tree):
+            return pytree.tree_map_only(WrapperSubclass, lambda w: w.a, tree)
+
+        inner_out = func(*unwrap(args), **unwrap(kwargs))
+        leaves, spec = pytree.tree_flatten(inner_out)
+        rewrapped = [
+            WrapperSubclass(leaf) if isinstance(leaf, torch.Tensor) else leaf
+            for leaf in leaves
+        ]
+        out = pytree.tree_unflatten(rewrapped, spec)
+        from torch._higher_order_ops.cond import cond_op
+
+        # cond_op results are returned as built; every other op goes through
+        # return_and_correct_aliasing.
+        if func is not cond_op:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+        return out
+
+    def __coerce_same_metadata_as_tangent__(
+        self, expected_metadata: Any, expected_type: Optional[type] = None
+    ):
+        """Convert to ``expected_type`` where supported, else return None."""
+        inner = self.a
+        if expected_type is type(inner):
+            return inner
+        if expected_type is TwoTensor:
+            return TwoTensor(inner, inner.clone())
+        return None
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5162ba0d6cb6729534ab28f8a84a906f8c99f87
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py
@@ -0,0 +1,218 @@
+# mypy: allow-untyped-defs
+import contextlib
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+
+_TORCHBIND_IMPLS_INITIALIZED = False
+
+_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
+
+
+def init_torchbind_implementations():
+    """Load the torchbind test library and register the fake impls once.
+
+    Later calls return immediately. The first call also stores a fresh
+    tensor queue in ``_TENSOR_QUEUE_GLOBAL_TEST``.
+    """
+    global _TORCHBIND_IMPLS_INITIALIZED
+    global _TENSOR_QUEUE_GLOBAL_TEST
+    if _TORCHBIND_IMPLS_INITIALIZED:
+        return
+
+    load_torchbind_test_lib()
+    register_fake_operators()
+    register_fake_classes()
+    _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
+    _TORCHBIND_IMPLS_INITIALIZED = True
+
+
+def _empty_tensor_queue() -> torch.ScriptObject:
+    """Return a ``_TensorQueue`` script object built from a 0-element tensor."""
+    return torch.classes._TorchScriptTesting._TensorQueue(
+        torch.empty(
+            0,
+        ).fill_(-1)
+    )
+
+
+# put these under a function because the corresponding library might not be loaded yet.
+def register_fake_operators():
+    """Register fake/Meta kernels and autocast rules for the test operators."""
+
+    @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
+    def fake_takes_foo(foo, z):
+        return foo.add_tensor(z)
+
+    @torch.library.register_fake("_TorchScriptTesting::queue_pop")
+    def fake_queue_pop(tq):
+        return tq.pop()
+
+    @torch.library.register_fake("_TorchScriptTesting::queue_push")
+    def fake_queue_push(tq, x):
+        return tq.push(x)
+
+    torch.library.register_autocast(
+        "_TorchScriptTesting::queue_push", "cpu", torch.float32
+    )
+    torch.library.register_autocast(
+        "_TorchScriptTesting::queue_push", "cuda", torch.float32
+    )
+
+    torch.library.register_autocast(
+        "_TorchScriptTesting::queue_pop", "cpu", torch.float32
+    )
+    torch.library.register_autocast(
+        "_TorchScriptTesting::queue_pop", "cuda", torch.float32
+    )
+
+    @torch.library.register_fake("_TorchScriptTesting::queue_size")
+    def fake_queue_size(tq):
+        return tq.size()
+
+    # The next two functions are registered as Meta kernels via py_impl below.
+    def meta_takes_foo_list_return(foo, x):
+        a = foo.add_tensor(x)
+        b = foo.add_tensor(a)
+        c = foo.add_tensor(b)
+        return [a, b, c]
+
+    def meta_takes_foo_tuple_return(foo, x):
+        a = foo.add_tensor(x)
+        b = foo.add_tensor(a)
+        return (a, b)
+
+    @torch.library.register_fake("_TorchScriptTesting::takes_foo_tensor_return")
+    def meta_takes_foo_tensor_return(foo, x):
+        # This implementation deliberately creates unbacked symint for testing
+        ctx = torch.library.get_ctx()
+        fake_shape = [ctx.new_dynamic_size() for _ in range(2)]
+        return torch.empty(fake_shape, dtype=torch.int, device="cpu")
+
+    torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(
+        torch._C.DispatchKey.Meta
+    )(meta_takes_foo_list_return)
+
+    torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl(
+        torch._C.DispatchKey.Meta
+    )(meta_takes_foo_tuple_return)
+
+    torch.ops._TorchScriptTesting.takes_foo.default.py_impl(torch._C.DispatchKey.Meta)(
+        # make signature match original cpp implementation to support kwargs
+        lambda foo, x: foo.add_tensor(x)
+    )
+
+
+def register_fake_classes():
+    """Register fake classes for the ``_TorchScriptTesting`` script classes."""
+    # noqa: F841
+    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
+    class FakeFoo:
+        def __init__(self, x: int, y: int):
+            self.x = x
+            self.y = y
+
+        @classmethod
+        def __obj_unflatten__(cls, flattend_foo):
+            return cls(**dict(flattend_foo))
+
+        def add_tensor(self, z):
+            return (self.x + self.y) * z
+
+    @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
+    class FakeContainsTensor:
+        def __init__(self, t: torch.Tensor):
+            self.t = t
+
+        @classmethod
+        def __obj_unflatten__(cls, flattend_foo):
+            return cls(**dict(flattend_foo))
+
+        def get(self):
+            return self.t
+
+    @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
+    class FakeTensorQueue:
+        def __init__(self, queue):
+            self.queue = queue
+
+        @classmethod
+        def __obj_unflatten__(cls, flattened_ctx):
+            return cls(**dict(flattened_ctx))
+
+        def push(self, x):
+            self.queue.append(x)
+
+        def pop(self):
+            if self.is_empty():
+                # An empty queue yields a 0-dim tensor instead of raising.
+                return torch.empty([])
+            return self.queue.pop(0)
+
+        def size(self):
+            return len(self.queue)
+
+        def is_empty(self):
+            return len(self.queue) == 0
+
+        def float_size(self):
+            return float(len(self.queue))
+
+    @torch._library.register_fake_class("_TorchScriptTesting::_FlattenWithTensorOp")
+    class FakeFlatten:
+        def __init__(self, t):
+            self.t = t
+
+        def get(self):
+            return self.t
+
+        @classmethod
+        def __obj_unflatten__(cls, flattened_ctx):
+            return cls(**dict(flattened_ctx))
+
+
+def load_torchbind_test_lib():
+    """Locate and load the torchbind test library for the current platform.
+
+    Raises:
+        unittest.SkipTest: on macOS, where the load call is not portable.
+    """
+    import unittest
+
+    from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
+        find_library_location,
+        IS_FBCODE,
+        IS_MACOS,
+        IS_SANDCASTLE,
+        IS_WINDOWS,
+    )
+
+    if IS_MACOS:
+        raise unittest.SkipTest("non-portable load_library call used in test")
+    elif IS_SANDCASTLE or IS_FBCODE:
+        lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations")
+    elif IS_WINDOWS:
+        lib_file_path = find_library_location("torchbind_test.dll")
+    else:
+        lib_file_path = find_library_location("libtorchbind_test.so")
+    torch.ops.load_library(str(lib_file_path))
+
+
+@contextlib.contextmanager
+def _register_py_impl_temporarily(op_overload, key, fn):
+    """Register ``fn`` as the ``key`` py_impl of ``op_overload`` for the block.
+
+    On exit the kernel is removed from ``op_overload.py_kernels`` and the
+    dispatch cache is cleared, whether or not the body raised.
+    """
+    # Register before entering the try block: if registration itself fails,
+    # this call installed nothing, so the cleanup must not run (it would hit
+    # a KeyError or drop a kernel that other code had registered).
+    op_overload.py_impl(key)(fn)
+    try:
+        yield
+    finally:
+        del op_overload.py_kernels[key]
+        op_overload._dispatch_cache.clear()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0fcbaee30f52a9a0d0f7e72aeaf99582d49f1e0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py
@@ -0,0 +1,1043 @@
+# mypy: ignore-errors
+
+import unittest
+
+from torch.testing._internal.inductor_utils import (
+    HAS_CUDA_AND_TRITON,
+    HAS_GPU,
+    HAS_XPU_AND_TRITON,
+)
+from torch.utils._triton import has_triton
+
+
+requires_cuda_and_triton = unittest.skipUnless(  # skip decorator: runs the test only if HAS_CUDA_AND_TRITON
+    HAS_CUDA_AND_TRITON, "requires cuda and triton"
+)
+requires_gpu_and_triton = unittest.skipUnless(  # satisfied by either the XPU or the CUDA Triton flag
+    HAS_XPU_AND_TRITON or HAS_CUDA_AND_TRITON, "requires gpu and triton"
+)
+requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")  # checks HAS_GPU only, not Triton
+
+if has_triton():
+    import triton
+    from triton import language as tl
+
+    import torch
+
+    def _get_strange_configs() -> list[triton.Config]:
+        # Returns a small-tile and a large-tile config; all use num_stages=4, num_warps=4.
+        if torch.version.hip:
+            # HIP builds add three extra keys and use a different K/group size on the large tile.
+            return [
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 16,
+                        "BLOCK_SIZE_N": 16,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 4,
+                        "matrix_instr_nonkdim": 16,
+                        "waves_per_eu": 3,
+                        "kpack": 2,
+                    },
+                    num_stages=4,
+                    num_warps=4,
+                ),
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 128,
+                        "BLOCK_SIZE_N": 64,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 4,
+                        "matrix_instr_nonkdim": 16,
+                        "waves_per_eu": 3,
+                        "kpack": 2,
+                    },
+                    num_stages=4,
+                    num_warps=4,
+                ),
+            ]
+        return [
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 16,
+                    "BLOCK_SIZE_K": 16,
+                    "GROUP_SIZE_M": 4,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+        ]
+
+    # Define here so that multiple tests can take advantage of it
+    @triton.jit
+    def add_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE  # first element index handled by this program
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements  # lanes at or past n_elements are masked off
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)  # elementwise sum written to out_ptr
+
+    @triton.jit
+    def sub_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements  # same bounds mask is reused for both loads and the store
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y  # operand order: in_ptr0 minus in_ptr1
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_optional_param(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        ARGS_PASSED: "tl.constexpr",
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        if ARGS_PASSED == "two":  # in_ptr1 is only dereferenced in this branch
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x  # any other ARGS_PASSED value copies in_ptr0 through unchanged
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_none_param_and_equal_to_1_arg(
+        in_ptr0,
+        in_ptr1,  # in_ptr1 could be None
+        out_ptr,
+        n_elements,
+        stride,
+        ARGS_PASSED: "tl.constexpr",
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets * stride, mask=mask)  # stride scales the in_ptr0 offsets
+        if ARGS_PASSED == "two":
+            y = tl.load(in_ptr1 + offsets, mask=mask)  # note: in_ptr1 is indexed without stride
+            output = x + y
+        else:
+            output = x
+        tl.store(out_ptr + offsets * stride, output, mask=mask)  # out_ptr uses the same strided offsets as in_ptr0
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],  # empty key list: no argument is named as an autotune key
+    )
+    @triton.jit
+    def add_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",  # supplied by whichever config above is selected
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def sub_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",  # comes from the autotune configs listed above
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y  # in_ptr0 minus in_ptr1
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),  # single candidate config
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_autotuned_weird_param_order(
+        in_ptr0,
+        in_ptr1,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+        out_ptr,
+    ):
+        # out_ptr is after an autotuned param that's declared as tl.constexpr.
+        # This param ordering can create bugs if not handled correctly.
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)  # body is a plain masked add; only the signature order is unusual
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
+            ),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_2d_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        x_elements,
+        y_elements,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        xoffset = tl.program_id(0) * BLOCK_SIZE_X
+        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]  # column vector of x indices
+        xmask = xindex < x_elements
+        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
+        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]  # row vector of y indices
+        ymask = yindex < y_elements
+        x1 = xindex
+        y0 = yindex
+        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
+        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)  # NOTE(review): reads in_ptr0 again with x/y roles swapped; in_ptr1 is never read - confirm intended
+        tmp2 = tmp0 + tmp1
+        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)  # stored at the same offsets used for tmp0
+
+    def _dummy_early_config_prune(configs, *_, **__):  # extra positional/keyword arguments are accepted and ignored
+        return configs  # no pruning: the input configs are returned untouched
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+        warmup=10,  # warmup, rep and prune_configs_by are the extra autotune arguments this fixture carries
+        rep=20,
+        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
+    )
+    @triton.jit
+    def add_kernel_autotuned_with_unsupported_args(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)  # kernel body itself is an ordinary masked add
+
+    @triton.jit
+    def add_kernel_with_scaling(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        scaling_factor,  # runtime argument (not annotated as tl.constexpr)
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = (x + y) * scaling_factor  # the sum is scaled, not the individual operands
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_tma_1d_old_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE  # start of this program's block
+
+        a = tl._experimental_descriptor_load(
+            in_desc_ptr0,
+            [offset],
+            [BLOCK_SIZE],
+            tl.float32,  # element type is fixed to float32 here
+        )
+        b = tl._experimental_descriptor_load(
+            in_desc_ptr1,
+            [offset],
+            [BLOCK_SIZE],
+            tl.float32,
+        )
+
+        output = a + b
+
+        tl._experimental_descriptor_store(
+            out_desc_ptr,
+            output,  # old API argument order: descriptor, value, offsets
+            [offset],
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_2d_old_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE_X
+        offset_y = pid_y * BLOCK_SIZE_Y
+
+        x = tl._experimental_descriptor_load(
+            in_desc_ptr0,
+            [offset_x, offset_y],
+            [BLOCK_SIZE_X, BLOCK_SIZE_Y],  # block shape is passed explicitly in the old API
+            tl.float32,  # element type is fixed to float32 here
+        )
+        y = tl._experimental_descriptor_load(
+            in_desc_ptr1,
+            [offset_x, offset_y],
+            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
+            tl.float32,
+        )
+
+        output = x + y
+
+        tl._experimental_descriptor_store(
+            out_desc_ptr,
+            output,  # old API argument order: descriptor, value, offsets
+            [offset_x, offset_y],
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_1d_new_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE
+
+        a = tl.load_tensor_descriptor(  # only descriptor and offsets are passed; no shape or dtype argument
+            in_desc_ptr0,
+            [offset],
+        )
+        b = tl.load_tensor_descriptor(
+            in_desc_ptr1,
+            [offset],
+        )
+
+        output = a + b
+
+        tl.store_tensor_descriptor(
+            out_desc_ptr,
+            [offset],  # new API argument order: descriptor, offsets, value
+            output,
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_2d_new_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE_X
+        offset_y = pid_y * BLOCK_SIZE_Y
+
+        x = tl.load_tensor_descriptor(  # only descriptor and offsets are passed; no shape or dtype argument
+            in_desc_ptr0,
+            [offset_x, offset_y],
+        )
+        y = tl.load_tensor_descriptor(
+            in_desc_ptr1,
+            [offset_x, offset_y],
+        )
+
+        output = x + y
+
+        tl.store_tensor_descriptor(
+            out_desc_ptr,
+            [offset_x, offset_y],  # new API argument order: descriptor, offsets, value
+            output,
+        )
+
+    @triton.jit
+    def add_kernel_on_device_tma_old_api(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        m,
+        n,
+        workspace,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        a_desc_ptr = workspace  # three descriptors live at workspace offsets 0, 128 and 256
+        b_desc_ptr = workspace + 128  # NOTE(review): presumably 128-byte slots if workspace is a byte pointer - confirm
+        c_desc_ptr = workspace + 256
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=a_desc_ptr,
+            global_address=a_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=a_ptr.dtype.element_ty,
+        )
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=b_desc_ptr,
+            global_address=b_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=b_ptr.dtype.element_ty,
+        )
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=c_desc_ptr,
+            global_address=c_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=c_ptr.dtype.element_ty,
+        )
+
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)  # one fence call per descriptor, issued before any use
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE
+        offset_y = pid_y * BLOCK_SIZE
+
+        # Load data using the tensor descriptors
+        a = tl._experimental_descriptor_load(
+            a_desc_ptr,
+            [offset_x, offset_y],
+            [BLOCK_SIZE, BLOCK_SIZE],
+            tl.float32,  # loads name float32 explicitly although the descriptors use the pointer element type
+        )
+        b = tl._experimental_descriptor_load(
+            b_desc_ptr,
+            [offset_x, offset_y],
+            [BLOCK_SIZE, BLOCK_SIZE],
+            tl.float32,
+        )
+
+        # Perform addition
+        output = a + b
+
+        # Store the result
+        tl._experimental_descriptor_store(
+            c_desc_ptr,
+            output,
+            [offset_x, offset_y],
+        )
+
+    @triton.jit
+    def add_kernel_on_device_tma_new_api(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        m,
+        n,
+        workspace,  # unused but left here to match the old API kernel
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        # Create tensor descriptors using the new API
+        a_desc = tl.make_tensor_descriptor(
+            base=a_ptr,
+            shape=[m, n],
+            strides=[n, 1],  # strides of a contiguous row-major (m, n) layout
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+        b_desc = tl.make_tensor_descriptor(
+            base=b_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+        c_desc = tl.make_tensor_descriptor(
+            base=c_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE  # the same BLOCK_SIZE is used for both tile dimensions
+        offset_y = pid_y * BLOCK_SIZE
+
+        # Load data using the tensor descriptors with the new API
+        a = tl.load_tensor_descriptor(
+            a_desc,
+            [offset_x, offset_y],
+        )
+        b = tl.load_tensor_descriptor(
+            b_desc,
+            [offset_x, offset_y],
+        )
+
+        # Perform addition
+        output = a + b
+
+        # Store the result with the new API
+        tl.store_tensor_descriptor(
+            c_desc,
+            [offset_x, offset_y],
+            output,
+        )
+
+    @triton.jit
+    def mul2_kernel(
+        in_ptr0,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        output = 2 * x
+        tl.store(out_ptr + offsets, output, mask=mask)  # doubled values go to a separate output pointer
+
+    @triton.jit
+    def mul2_inplace_kernel(
+        ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(ptr + offsets, mask=mask)
+        output = 2 * x
+        tl.store(ptr + offsets, output, mask=mask)  # written back to the same pointer it was loaded from
+
+    @triton.jit
+    def zero_negs(x):
+        return tl.where(x >= 0, x, 0)  # keeps x where x >= 0, otherwise yields 0
+
+    @triton.jit
+    def indirection_kernel(
+        in_ptr0,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+        ACTIVATION: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        if ACTIVATION == "mul2_inplace_kernel":  # ACTIVATION selects which other jit function is called first
+            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+        elif ACTIVATION == "add_kernel":
+            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)  # in_ptr0 is passed as both inputs
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        tl.store(out_ptr + offsets, x, mask=mask)  # runs in every case: copies in_ptr0 to out_ptr after the optional call
+
+    @triton.jit
+    def double_strided_kernel(
+        in_ptr,
+        out_ptr,
+        in_y_stride,
+        out_y_stride,
+        X_BLOCK_SIZE: "tl.constexpr",
+        Y_BLOCK_SIZE: "tl.constexpr",
+    ):
+        xid = tl.program_id(axis=0)
+        yid = tl.program_id(axis=1)
+        x_start = xid * X_BLOCK_SIZE
+        y_start = yid * Y_BLOCK_SIZE
+        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
+        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
+        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]  # y scaled by the input stride, x added as-is
+        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]  # same layout but with the output stride
+        src = tl.load(in_ptr + src_offsets)  # no mask is passed to this load or the store below
+        tl.store(out_ptr + dst_offsets, src * 2.0)
+
+    @triton.jit
+    def inline_asm_kernel_is_pure_true(
+        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
+    ):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = tl.load(Y + tl.arange(0, BLOCK))
+        s = tl.full([BLOCK], n, tl.int32)  # n broadcast to a BLOCK-long int32 vector; third asm operand
+        z = tl.inline_asm_elementwise(
+            "shf.l.wrap.b32 $0, $1, $2, $3;",
+            "=r,r, r, r",
+            [x, y, s],  # bound to $1, $2, $3; $0 is the result
+            dtype=tl.int32,
+            is_pure=True,  # the only difference from inline_asm_kernel_is_pure_false
+            pack=1,
+        )
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    @triton.jit
+    def inline_asm_kernel_is_pure_false(
+        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
+    ):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = tl.load(Y + tl.arange(0, BLOCK))
+        s = tl.full([BLOCK], n, tl.int32)  # n broadcast to a BLOCK-long int32 vector; third asm operand
+        z = tl.inline_asm_elementwise(
+            "shf.l.wrap.b32 $0, $1, $2, $3;",
+            "=r,r, r, r",
+            [x, y, s],  # bound to $1, $2, $3; $0 is the result
+            dtype=tl.int32,
+            is_pure=False,  # the only difference from inline_asm_kernel_is_pure_true
+            pack=1,
+        )
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    @triton.jit
+    def add_kernel_with_block_ptr(
+        x_ptr,
+        y_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        x = tl.load(
+            tl.make_block_ptr(  # 1-D block pointer: n_elements long, unit stride, starting at block_start
+                base=x_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            boundary_check=[0],  # bounds handling is requested on dim 0 instead of passing an explicit mask
+        )
+        y = tl.load(
+            tl.make_block_ptr(
+                base=y_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            boundary_check=[0],
+        )
+        output = x + y
+        tl.store(
+            tl.make_block_ptr(  # same geometry as the two input block pointers
+                base=output_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            output,
+            boundary_check=[0],
+        )
+
+    @triton.jit
+    def kernel_with_block_ptr_2d(
+        x_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        x = tl.load(
+            tl.make_block_ptr(
+                base=x_ptr,
+                shape=[n_elements, 1],  # 2-D block pointer whose second dimension has extent 1
+                strides=[1, 1],
+                offsets=[block_start, 0],
+                block_shape=[BLOCK_SIZE, 1],
+                order=[1, 0],
+            ),
+            boundary_check=[0],  # only dim 0 is listed for boundary checking
+        )
+        output = x  # plain copy: nothing is computed between load and store
+        tl.store(
+            tl.make_block_ptr(
+                base=output_ptr,
+                shape=[n_elements, 1],
+                strides=[1, 1],
+                offsets=[block_start, 0],
+                block_shape=[BLOCK_SIZE, 1],
+                order=[1, 0],
+            ),
+            output,
+            boundary_check=[0],
+        )
+
+    from triton.language import load, store
+
+    @triton.jit
+    def add_kernel_with_import(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = load(in_ptr0 + offsets, mask=mask)  # load/store are the bare names imported from triton.language, not tl.load/tl.store
+        y = load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def cond_op_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        if tl.program_id(0) == 0:  # branch depends on the runtime program id, not on a constexpr
+            output = x + y  # program 0 adds
+        else:
+            output = x * y  # every other program multiplies
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def atomic_add_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(out_ptr + offsets, output, mask=mask)  # accumulates into out_ptr rather than overwriting it
+
+    @triton.jit
+    def add_4_times_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        for _ in range(2):  # two stores from a for loop ...
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+        i = 2
+        while i > 0:  # ... and two more from a while loop; each store writes the same x + y
+            i -= 1
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_out_of_order_fn2(
+        in_ptr0,
+        in_ptr1,
+        n_elements,  # note the order: n_elements comes before out_ptr in this signature
+        out_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 16,
+                    "BLOCK_SIZE_K": 16,
+                    "GROUP_SIZE_M": 4,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+        ],
+        key=["M_ptr", "N", "K"],  # the key names the pointer argument M_ptr, not the loaded value M
+    )
+    @triton.jit
+    def strange_config_matmul_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        M_ptr,
+        N,
+        K,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        GROUP_SIZE_M: tl.constexpr,
+    ):
+        # This is a simplified matmul from Triton tutorial.
+        pid = tl.program_id(axis=0)
+        M = tl.load(M_ptr)  # M is read through a pointer instead of being passed by value
+        if M == 0 and BLOCK_SIZE_M > 32:
+            # This will run the full matmul if BLOCK_SIZE_M > 32
+            M = 4096
+        elif M == 0:
+            # This directly returns, which will cut short the bad config of 16-block size.
+            return
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)  # the last group may hold fewer than GROUP_SIZE_M rows of tiles
+        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M  # wrapped modulo M, so these offsets stay below M
+        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] + offs_k[None, :])  # offsets are summed directly; no stride arguments exist
+        b_ptrs = b_ptr + (offs_k[:, None] + offs_bn[None, :])
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(tl.cdiv(K, BLOCK_SIZE_K)):
+            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)  # K-tail lanes read as 0.0
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+            accumulator = tl.dot(a, b, accumulator)
+            a_ptrs += BLOCK_SIZE_K
+            b_ptrs += BLOCK_SIZE_K
+        c = accumulator.to(tl.float16)  # float32 accumulator is narrowed to float16 before the store
+
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :]
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)  # output offsets are not wrapped, so they are masked instead
+        tl.store(c_ptrs, c, mask=c_mask)
+
+    @triton.jit
+    def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
+        """
+        This kernel contains a triple-quote docstring w/ double quotes.
+        Make sure that codegen sanitizes the docstring.
+        """
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)  # constant block of float32 ones
+        tl.store(out_ptr + offsets, ones, mask=offsets < numel)  # fills the first numel elements of out_ptr with 1.0
+
+    @triton.jit
+    def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
+        '''
+        This kernel contains a triple-quote docstring w/ single quotes
+        Make sure that codegen sanitizes the docstring.
+        To prevent it from being linted to double quotes: """!!!"""
+        '''
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)  # constant block of float32 ones
+        tl.store(out_ptr + offsets, ones, mask=offsets < numel)  # fills the first numel elements of out_ptr with 1.0
+
+    @triton.jit
+    def kernel_inline_asm_double_quotes(
+        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        data = tl.load(in_ptr + offsets, mask=offsets < numel)
+        cos_pow = tl.inline_asm_elementwise(  # the asm text below is a double-quoted triple-quote string
+            asm="""
+            {
+                cos.approx.f32 $0, $1;
+                ex2.approx.f32 $0, $0;
+            }
+                """,
+            constraints=("=r, r"),
+            args=[data],  # single input operand ($1); $0 is the result
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_inline_asm_single_quotes(
+        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        data = tl.load(in_ptr + offsets, mask=offsets < numel)
+        cos_pow = tl.inline_asm_elementwise(  # the asm text below is a single-quoted triple-quote string
+            asm='''
+            {
+                // double quotes to pacify the linter """!!!"""
+                cos.approx.f32 $0, $1;
+                ex2.approx.f32 $0, $0;
+            }
+                ''',
+            constraints=("=r, r"),
+            args=[data],  # single input operand ($1); $0 is the result
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)
+
+    @triton.jit
+    def add_kernel_with_boolean_param(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        add_xy,  # boolean param
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        if add_xy:  # add_xy is not annotated as tl.constexpr, unlike BLOCK_SIZE
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x  # falsy add_xy: in_ptr0 is copied through and in_ptr1 is never read
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    # support the old (experimental) and new (tensor_descriptor) APIs
+    def create_tensor_descriptor_shim(
+        tensor, block_sizes: list[int], new_api: bool = True
+    ):
+        # Returns a descriptor for `tensor`: TensorDescriptor when new_api is true, otherwise
+        # a legacy 1D/2D TMA descriptor chosen by len(block_sizes) (which must be 1 or 2).
+        # Submodules are imported here so this works even if nothing else has loaded them.
+        if new_api:
+            from triton.tools.tensor_descriptor import TensorDescriptor
+
+            return TensorDescriptor.from_tensor(tensor, block_sizes)
+        from triton.tools import experimental_descriptor
+
+        if len(block_sizes) == 1:
+            return experimental_descriptor.create_1d_tma_descriptor(
+                tensor.data_ptr(), tensor.size(0), block_sizes[0], tensor.element_size()
+            )
+        assert len(block_sizes) == 2
+        return experimental_descriptor.create_2d_tma_descriptor(
+            tensor.data_ptr(),
+            tensor.size(0),
+            tensor.size(1),
+            block_sizes[0],
+            block_sizes[1],
+            tensor.element_size(),
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8197829ac7f44f38d295995dd921ddf58b30adfd
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py
@@ -0,0 +1,100 @@
+# mypy: ignore-errors
+
+import torch
+import torch.utils._pytree as pytree
+from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
+class TwoTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
+        if outer_size is None:  # default the wrapper's size to a's size
+            outer_size = a.size()
+        if outer_stride is None:  # default the wrapper's strides to a's strides
+            outer_stride = a.stride()
+
+        assert (  # both inner tensors must agree on device, layout, requires_grad and dtype
+            a.device == b.device
+            and a.layout == b.layout
+            and a.requires_grad == b.requires_grad
+            and a.dtype == b.dtype
+        )
+        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
+        shape = outer_size
+        kwargs = {}  # wrapper metadata; offset, device, layout and dtype are read from `a`
+        kwargs["strides"] = outer_stride
+        kwargs["storage_offset"] = a.storage_offset()
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = requires_grad or a.requires_grad  # a falsy value falls back to a.requires_grad
+        kwargs["dtype"] = a.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+        assert a.shape == b.shape  # remaining a/b consistency checks run after the wrapper is built
+        assert a.stride() == b.stride()
+        assert a.storage_offset() == b.storage_offset()
+        return out
+
+    @torch._disable_dynamo
+    @mark_subclass_constructor_exportable_experimental
+    def __init__(self, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
+        self.a = a  # only the inner tensors are stored; the other arguments are consumed by __new__
+        self.b = b
+
+    def __repr__(self):
+        a_repr = repr(self.a)
+        b_repr = repr(self.b)
+        return f"TwoTensor({a_repr}, {b_repr})"
+
+    def __tensor_flatten__(self):
+        return ["a", "b"], None  # names of the inner-tensor attributes, and no extra metadata
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None  # __tensor_flatten__ above returns None as metadata
+        a, b = inner_tensors["a"], inner_tensors["b"]
+        if type(a) is torch.Tensor:  # plain (non-subclass) inner tensor: outer size/stride are required
+            assert outer_size is not None
+            assert outer_stride is not None
+        return TwoTensor(a, b, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)  # every TwoTensor -> its .a
+        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)  # every TwoTensor -> its .b
+
+        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
+        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)
+
+        out_a = func(*args_a, **kwargs_a)  # run the op once per inner tensor
+        out_b = func(*args_b, **kwargs_b)
+        out_a_flat, spec = pytree.tree_flatten(out_a)
+        out_b_flat = pytree.tree_leaves(out_b)
+        # for aten ops that return non-tensors, just assume that
+        # our two inner tensors return the same value
+        out_flat = [
+            cls(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
+            for o_a, o_b in zip(out_a_flat, out_b_flat, strict=True)
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)  # rebuild using the structure of the `a` result
+        from torch._higher_order_ops.cond import cond_op  # imported at call time, not at module import
+
+        if func is cond_op:  # cond_op results skip return_and_correct_aliasing
+            return out
+        else:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+
+    def get_elem_a(self):
+        return self.a
+
+
+class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        out = func(*args, **kwargs)
+        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
+            out = TwoTensor(out, out.clone())
+        return out
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..623647043fd8e8cc0dadb930a97058a1072bbaf1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69a6f99e95f0689f4748a91d44720fa9b0e19786
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fc6dd821c89a4db3525ea2af31709a318bf19f9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9876b3b82f473f35d31b9e0a1a78fc05a6bcce48
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75cba762f694ec6964baac7c9649e711f89ae086
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee43ce43c7ae8f5f618ad6e1b08642bc2a12bf50
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_filelock.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_filelock.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92fc6ca1950e7057a7b661f8f7574630b8485a2c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_filelock.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac8e62bd8de7d8dacafcfad733de91a83492239d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..526c5d648571f6934f5be186e2e4977d506ff857
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_runtime_estimation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_runtime_estimation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03f96a8c000953238b4e269ddd4aef762d075c54
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_runtime_estimation.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e595dad0af4bf89aa4fccb46cf9598c50e26a32
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5240c9e96413673429bffbb87c4020cd6c1e1bd8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e83ef2a6bba1a7fc86eb9816d4877c70b269161
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9dbfee35f927f4a25418f15633d5b1d79ca8746
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..286a7436dc6f55968b878f07cc081581c2192178
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92d296f7eddab413bb5a1651ecd9fc961de14d7b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/file_baton.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/file_baton.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0753e6f3101d3454c49d93f9c1880b3f6f6b1be5
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/file_baton.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/show_pickle.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/show_pickle.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42d2bfe0bd7397aec03beed54e9814dfbba53bd2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/show_pickle.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3af82fc2a601b4ff1f932f7013547d2a518ae5e3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c808629b1f52fc0ae01bde08f14a867908442e0a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c16825894f59d1ffcca3f486f99416f99f93360a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e1595bf2a1477b33ed00446d86e6cdea267a8f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py
@@ -0,0 +1,313 @@
+# mypy: disallow-untyped-defs
+
+import functools
+import logging
+import os
+import re
+import subprocess
+import time
+from collections.abc import Callable, Sequence
+from threading import Lock
+from typing import Any, TypeVar
+from typing_extensions import ParamSpec
+
+
+logger = logging.getLogger("strobelight_function_profiler")
+
+console_handler = logging.StreamHandler()  # no stream argument: the handler writes to sys.stderr
+formatter = logging.Formatter(
+    "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
+)
+console_handler.setFormatter(formatter)
+
+logger.addHandler(console_handler)
+logger.setLevel(logging.INFO)  # at this level the logger.debug(...) calls in this module are dropped
+logger.propagate = False  # records are not passed on to ancestor loggers' handlers
+
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
+
+class StrobelightCLIProfilerError(Exception):
+    """
+    Raised when strobelight profiling cannot be started, stopped, or its results collected.
+    """
+
+
+def _pid_namespace_link(pid: int | None = None) -> str:
+    """Returns the link to the process's namespace, example: pid:[4026531836]"""
+    PID_NAMESPACE_PATH = "/proc/{}/ns/pid"
+    pid = pid or os.getpid()
+    return os.readlink(PID_NAMESPACE_PATH.format(pid))
+
+
+def _pid_namespace(pid: int | None = None) -> int:
+    """Returns the process's namespace id"""
+    pid = pid or os.getpid()
+    link = _pid_namespace_link(pid)
+    return int(link[link.find("[") + 1 : -1])
+
+
+def _command_to_string(command: Sequence[str]) -> str:
+    return " ".join(command)
+
+
+class StrobelightCLIFunctionProfiler:
+    """
+    Note: this is a meta only tool.
+
+    StrobelightCLIFunctionProfiler can be used to profile a python function and
+    generate a strobelight link with the results. It works on meta servers but
+    does not requires an fbcode target.
+    When stop_at_error is false(default), error during profiling does not prevent
+    the work function from running.
+
+    Check function_profiler_example.py for an example.
+    """
+
+    # This lock is used to make sure only one thread is running the profiler at any point.
+    _lock = Lock()
+
+    def __init__(
+        self,
+        *,
+        stop_at_error: bool = False,  # profile() raises instead of running the function unprofiled
+        max_profile_duration_sec: int = 60 * 10,  # sent to strobeclient as --duration-ms (x1000)
+        sample_each: float = 1e7,  # sample each sample_each cycles (sent as --sample-interval).
+        run_user_name: str = "pytorch-strobelight-ondemand",  # stored; not read by this class
+        timeout_wait_for_running_sec: int = 60,  # stored; not read by this class
+        timeout_wait_for_finished_sec: int = 60,  # stored; not read by this class
+        recorded_env_variables: list[str] | None = None,  # accepted but not stored
+        sample_tags: list[str] | None = None,  # joined with "," and sent as --sample-tags
+        stack_max_len: int = 127,  # accepted but not stored
+        async_stack_max_len: int = 127,  # accepted but not stored
+    ) -> None:
+        self.stop_at_error = stop_at_error
+        self.max_profile_duration_sec = max_profile_duration_sec
+        self.sample_each = sample_each
+        self.run_user_name = run_user_name
+        self.timeout_wait_for_running_sec = timeout_wait_for_running_sec
+        self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
+        # Results of the most recent run.
+        # Tracks the strobelight run id of the most recent run
+        self.current_run_id: int | None = None
+        self.sample_tags = sample_tags
+
+    def _run_async(self) -> None:  # start a strobeclient pyperf run for this process; record its run id
+        processId = os.getpid()
+        namespace = _pid_namespace(processId)  # numeric pid-namespace id of this process
+        command = [
+            "strobeclient",
+            "run",
+            "--profiler",
+            "pyperf",
+            "--event",
+            "cycles",
+            "--async",
+            "--sample-interval",
+            f"{int(self.sample_each)}",
+            "--duration-ms",
+            f"{int(self.max_profile_duration_sec * 1000)}",  # seconds -> milliseconds
+            "--pid",
+            f"{namespace}:{processId}",  # target given as <pid-namespace id>:<pid>
+        ]
+
+        if self.sample_tags:  # the flag is omitted when there are no tags
+            command.append("--sample-tags")
+            command.append(",".join(self.sample_tags))
+
+        logger.debug("running command: %s", _command_to_string(command))
+        result = subprocess.run(command, capture_output=True)
+        output = result.stderr.decode("utf-8")  # only stderr is parsed; stdout is ignored
+        logger.debug("output:\n{%s}", output)
+
+        if result.returncode != 0:
+            raise StrobelightCLIProfilerError(
+                f"failed to start strobelight profiling, error in run_async:{output}"
+            )
+
+        if match := re.search(r"INFO Run Id: (-?\d+)", output):
+            self.current_run_id = int(match.group(1))  # used by the later status/stop commands
+            return
+
+        raise StrobelightCLIProfilerError(  # exit code 0 but no run id found in the output
+            f"failed to start strobelight profiling, unexpected result {output}"
+        )
+
+    def _wait_for_running(self, counter: int = 0) -> None:  # poll getRunStatus until RUNNING
+        if counter > 20:  # retry budget; counter is the number of PREPARING polls so far
+            raise StrobelightCLIProfilerError(
+                "wait_for_running called more than 20 times"
+            )
+
+        command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"]
+        logger.debug("running command: %s", _command_to_string(command))
+        result = subprocess.run(command, capture_output=True)
+        output = result.stderr.decode("utf-8")  # only stderr is parsed; stdout is ignored
+        logger.debug("output:\n{%s}", output)
+
+        if result.returncode != 0:
+            raise StrobelightCLIProfilerError(
+                f"failed to start strobelight profiling, error in wait_for_running:{output}"
+            )
+
+        if match := re.search("Profile run status: (.*)", output):
+            current_status = match.group(1)
+            if current_status == "RUNNING":
+                return
+            elif current_status == "PREPARING":
+                time.sleep(10)  # not ready yet: wait, then poll again recursively
+                self._wait_for_running(counter + 1)
+                return
+            else:  # any status other than RUNNING or PREPARING is treated as a failure
+                raise StrobelightCLIProfilerError(f"unexpected {current_status} phase")
+
+        raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")
+
+    def _stop_run(self) -> None:
+        command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)]
+        logger.debug("running command: %s", _command_to_string(command))
+        result = subprocess.run(command, capture_output=True)
+        output = result.stderr.decode("utf-8")
+        logger.debug("output:\n{%s}", output)
+
+        if result.returncode != 0:
+            raise StrobelightCLIProfilerError(
+                f"failed to stop strobelight profiling, return code is not 0 :{output}"
+            )
+
+        if match := re.search("INFO ::1:(.*)", output):
+            current_status = match.group(1)
+            if current_status.__contains__("Success!"):
+                return
+            else:
+                raise StrobelightCLIProfilerError(
+                    f"failed to stop strobelight profiling, got {current_status} result"
+                )
+
+        raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")
+
+    def _get_results(self) -> None:
+        command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)]
+        logger.debug("running command: %s", _command_to_string(command))
+        result = subprocess.run(command, capture_output=True)
+        output = result.stderr.decode("utf-8")
+        logger.debug("output:\n{%s}", output)
+
+        if result.returncode != 0:
+            raise StrobelightCLIProfilerError(
+                f"failed to extract profiling results, return code is not 0 : {output}"
+            )
+
+        if match := re.search("INFO ::1:(.*)", output):
+            current_status = match.group(1)
+            if current_status.__contains__("Profile run status: PROCESSING"):
+                time.sleep(10)
+                self._get_results()
+                return
+            elif not current_status.__contains__("Profile run finished with SUCCESS"):
+                raise StrobelightCLIProfilerError(
+                    f"failed to extract profiling results, unexpected response {output}"
+                )
+
+        for item in re.findall(
+            r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))",
+            output,
+        ):
+            logger.info(item[0])
+
+    def _stop_strobelight_no_throw(
+        self,
+        collect_results: bool,
+    ) -> None:
+        try:
+            # call stop run
+            self._stop_run()
+            logger.info("strobelight profiling stopped")
+
+            logger.debug("collection stopped")
+
+            if not collect_results:
+                return
+
+            self._get_results()
+        except Exception:
+            logger.warning("error during stop_strobelight", exc_info=True)
+
+    # Return true if strobelight started and is running. Never throw.
+    def _start_strobelight(self) -> bool:
+        strobelight_started = False
+        try:
+            self._run_async()
+            strobelight_started = True
+            logger.info("strobelight run id is: %s", self.current_run_id)
+            self._wait_for_running()
+            logger.info("strobelight profiling running")
+            return True
+
+        except Exception:
+            logger.warning("error during start_strobelight:", exc_info=True)
+            if strobelight_started:
+                self._stop_strobelight_no_throw(collect_results=False)
+            return False
+
+    def profile(
+        self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
+    ) -> _R | None:
+        self.current_run_id = None
+
+        if locked := StrobelightCLIFunctionProfiler._lock.acquire(False):
+            if not locked:
+                if self.stop_at_error:
+                    raise StrobelightCLIProfilerError("concurrent runs not supported")
+
+                logger.warning("concurrent runs not supported")
+                return work_function(*args, **kwargs)
+
+            started = self._start_strobelight()
+            if not started:
+                if self.stop_at_error:
+                    StrobelightCLIFunctionProfiler._lock.release()
+                    raise StrobelightCLIProfilerError(
+                        "failed to start strobelight profiling"
+                    )
+                result = work_function(*args, **kwargs)
+                StrobelightCLIFunctionProfiler._lock.release()
+                return result
+
+            try:
+                logger.debug("collection started")
+                result = work_function(*args, **kwargs)
+                self._stop_strobelight_no_throw(collect_results=True)
+                StrobelightCLIFunctionProfiler._lock.release()
+                return result
+            except Exception as error:
+                logger.warning("work function throw exception", exc_info=True)
+                self._stop_strobelight_no_throw(collect_results=False)
+                StrobelightCLIFunctionProfiler._lock.release()
+                raise error
+        return None
+
+
+# A function decorator that wraps profile, if no profiler is provided one with
+# default args is created. A function can be annotated as:
+# @strobelight()
+# @strobelight(profiler=StrobelightCLIFunctionProfiler(stop_at_error=True, ...))
+# @strobelight(stop_at_error=True,...)
+def strobelight(
+    profiler: StrobelightCLIFunctionProfiler | None = None, **kwargs: Any
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R | None]]:
+    if not profiler:
+        profiler = StrobelightCLIFunctionProfiler(**kwargs)
+
+    def strobelight_inner(
+        work_function: Callable[_P, _R],
+    ) -> Callable[_P, _R | None]:
+        @functools.wraps(work_function)
+        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _R | None:
+            # pyrefly: ignore [bad-argument-type]
+            return profiler.profile(work_function, *args, **kwargs)
+
+        return wrapper_function
+
+    return strobelight_inner
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9ee47c7156e8ead9d60ff60952102d5ae77e387
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55950dc4465aa9768a6ca4503338895b58bc962e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a2e766cf105f6a822362053dcf705365454b3d2
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43c8e57de18215fc30a5f55f7c2da4af6f470869
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3927a191b0e3122df63f3700475de74831260b0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da03a177ee370b0b3c731bc3136b0f1aff9f181a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9241553f121cd517ee5da63954e71c947e51ebb0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e9acf87edc548317ac95ebc26ff21a2d165b5b1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9936fde913f5ff0c0637064df95c5f82a7f865f8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e80fcabc654513e8aad38181d69e1fbdd42e6329
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7c33fe9448be966dba855679ba9e2b9cb23e005
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..736c36022c2e42b22b7c7ec7883260b179a151cc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b1af1feb66b7fa0ef166afb7a51df7e49fbe590
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40412e531d86b2f2891dbb82489ac5a2aa8ee16a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c29e28f64a6b2c326b0fb6239fdde037bce48459
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21d845aa636b620de645857d1d01569e438bf1c8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a9e06c2f3980d5186400be1cb2c38f5ea2b111c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eee47f327cfdb8597d59b1c606019a292b601c7f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c266e7cf9a6e604c94dfb28f19f31f1649220f4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py
@@ -0,0 +1,99 @@
+# mypy: allow-untyped-defs
+"""Example of Timer and Compare APIs:
+
+$ python -m examples.compare
+"""
+
+import pickle
+import sys
+import time
+
+import torch
+
+import torch.utils.benchmark as benchmark_utils
+
+
+class FauxTorch:
+    """Emulate different versions of pytorch within a single process.
+
+    Each wrapped op calls the real torch function, then `extra_overhead`.
+    Normally this would be done with multiple processes writing serialized
+    measurements; this simplifies that model to make the example clearer.
+    """
+    def __init__(self, real_torch, extra_ns_per_element) -> None:
+        self._real_torch = real_torch  # module whose add/mul/cat/matmul are wrapped
+        self._extra_ns_per_element = extra_ns_per_element  # fake cost, ns per result element
+
+    def extra_overhead(self, result):
+        # time.sleep has a ~65 us overhead, so only fake a
+        # per-element overhead if numel is large enough.
+        numel = int(result.numel())
+        if numel > 5000:
+            time.sleep(numel * self._extra_ns_per_element * 1e-9)  # ns -> s
+        return result
+
+    def add(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.add(*args, **kwargs))
+
+    def mul(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))
+
+    def cat(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))
+
+    def matmul(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
+
+
+def main() -> None:
+    """Time each task under three emulated branches and print Compare tables."""
+    tasks = [  # each entry is (label, sub_label, stmt)
+        ("add", "add", "torch.add(x, y)"),
+        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
+    ]
+    serialized_results = []
+    repeats = 2
+    timers = [
+        benchmark_utils.Timer(
+            stmt=stmt,
+            globals={
+                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
+                "x": torch.ones((size, 4)),
+                "y": torch.ones((1, 4)),
+                "zero": torch.zeros(()),
+            },
+            label=label,
+            sub_label=sub_label,
+            description=f"size: {size}",
+            env=branch,
+            num_threads=num_threads,
+        )
+        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
+        for label, sub_label, stmt in tasks
+        for size in [1, 10, 100, 1000, 10000, 50000]
+        for num_threads in [1, 4]
+    ]
+    # Run every timer `repeats` times, pickling each measurement as it is produced.
+    for i, timer in enumerate(timers * repeats):
+        serialized_results.append(pickle.dumps(
+            timer.blocked_autorange(min_run_time=0.05)
+        ))
+        print(f"\r{i + 1} / {len(timers) * repeats}", end="")  # in-place progress counter
+        sys.stdout.flush()
+    print()
+    # Deserialize the measurements and hand them to Compare for tabulation.
+    comparison = benchmark_utils.Compare([
+        pickle.loads(i) for i in serialized_results
+    ])
+
+    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
+    comparison.print()
+
+    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
+    comparison.trim_significant_figures()
+    comparison.colorize()
+    comparison.print()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..80a4e733928d8b059919d847da1b461d55dd7402
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py
@@ -0,0 +1,86 @@
+# mypy: allow-untyped-defs
+"""Example of the Timer and Fuzzer APIs:
+
+$ python -m examples.fuzzer
+"""
+
+import sys
+
+import torch.utils.benchmark as benchmark_utils
+
+
+def main() -> None:
+    """Time `x + y` on 250 fuzzed tensor pairs and print the 15 best and worst cases."""
+    add_fuzzer = benchmark_utils.Fuzzer(
+        parameters=[
+            [
+                benchmark_utils.FuzzedParameter(
+                    name=f"k{i}",
+                    minval=16,
+                    maxval=16 * 1024,
+                    distribution="loguniform",
+                ) for i in range(3)
+            ],
+            benchmark_utils.FuzzedParameter(
+                name="d",
+                distribution={2: 0.6, 3: 0.4},
+            ),
+        ],
+        tensors=[
+            [
+                benchmark_utils.FuzzedTensor(
+                    name=name,
+                    size=("k0", "k1", "k2"),
+                    dim_parameter="d",
+                    probability_contiguous=0.75,
+                    min_elements=64 * 1024,
+                    max_elements=128 * 1024,
+                ) for name in ("x", "y")
+            ],
+        ],
+        seed=0,
+    )
+    n = 250
+    measurements = []
+    for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
+        x, x_order = tensors["x"], str(tensor_properties["x"]["order"])
+        y, y_order = tensors["y"], str(tensor_properties["y"]["order"])
+        shape = ", ".join(tuple(f'{i:>4}' for i in x.shape))
+
+        description = "".join([
+            f"{x.numel():>7} | {shape:<16} | ",
+            f"{'contiguous' if x.is_contiguous() else x_order:<12} | ",
+            f"{'contiguous' if y.is_contiguous() else y_order:<12} | ",
+        ])
+
+        timer = benchmark_utils.Timer(
+            stmt="x + y",
+            globals=tensors,
+            description=description,
+        )
+        # Store numel on the measurement so time_fn below can normalize by it.
+        measurements.append(timer.blocked_autorange(min_run_time=0.1))
+        measurements[-1].metadata = {"numel": x.numel()}
+        print(f"\r{i + 1} / {n}", end="")  # in-place progress counter
+        sys.stdout.flush()
+    print()
+
+    # More string munging to make pretty output.
+    print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}")
+
+    def time_fn(m):
+        return m.median / m.metadata["numel"]  # median time per element
+    measurements.sort(key=time_fn)  # ascending: fastest per-element first
+
+    template = f"{{:>6}}{' ' * 19}Size    Shape{' ' * 13}X order        Y order\n{'-' * 80}"
+    print(template.format("Best:"))
+    for m in measurements[:15]:
+        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")
+
+    print("\n" + template.format("Worst:"))
+    for m in measurements[-15:]:
+        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65599ee18a4f2c4a0d35b514c8f87725affae01
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py
@@ -0,0 +1,107 @@
+# mypy: allow-untyped-defs
+"""Example use of Timer and op fuzzers to measure kernel performance.
+
+$ python -m examples.op_benchmark
+"""
+
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Timer
+from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer
+from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer
+import operator
+
+
+_MEASURE_TIME = 1.0  # min_run_time passed to Timer.blocked_autorange for every measurement
+
+
+def assert_dicts_equal(dict_0, dict_1) -> None:
+    """Raise AssertionError unless both dicts have the same keys and equal values.
+
+    The "dtype" entry is not compared. Values go through np.all because builtin
+    dict comparison can raise ValueError on numpy arrays, e.g. {"a": np.ones((2, 1))}.
+    """
+    if set(dict_0) != set(dict_1):
+        raise AssertionError("dicts must have the same keys")
+    if not all(np.all(v == dict_1[k]) for k, v in dict_0.items() if k != "dtype"):
+        raise AssertionError("dict values differ for keys other than 'dtype'")
+
+
+def run(n, stmt, fuzzer_cls) -> None:
+    """Time `stmt` on `n` float32/int32 input pairs from `fuzzer_cls` and print a summary."""
+    float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
+    int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n)
+    raw_results = []
+    for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter, strict=True)):
+        float_tensors, float_tensor_params, float_params = float_values
+        int_tensors, int_tensor_params, int_params = int_values
+        # This benchmark assumes that the two fuzzers generate identically
+        # sized and strided Tensors, since the same seed is used.
+        assert_dicts_equal(float_params, int_params)
+        assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])
+
+        float_measurement, int_measurement = (
+            Timer(
+                stmt,
+                globals=tensors,
+            ).blocked_autorange(min_run_time=_MEASURE_TIME)
+            for tensors in (float_tensors, int_tensors)
+        )
+        # Build (name, shape, order, steps) display strings for each input tensor.
+        descriptions = []
+        for name in float_tensors:
+            shape_str = "(" + ", ".join([
+                f"2 ** {int(np.log2(i))}"
+                if 2 ** int(np.log2(i)) == i and i > 1
+                else str(i)
+                for i in float_tensors[name].shape
+            ]) + ")"
+            order = float_tensor_params[name]["order"]
+            order_str = ("" if all(order == np.arange(len(order))) else str(tuple(order)))  # blank for identity order
+            steps = float_tensor_params[name]["steps"]
+            steps_str = str(steps) if sum(steps) > len(steps) else ""  # blank when the sum shows no step above 1
+            descriptions.append((name, shape_str, order_str, steps_str))
+        raw_results.append((float_measurement, int_measurement, descriptions))
+
+        print(f"\r{i + 1} / {n}", end="")  # in-place progress counter
+    print()
+    # Scale medians by 1e6 (presumably s -> us) and track the widest entry of each column.
+    parsed_results, name_len, shape_len, order_len, steps_len = [], 0, 0, 0, 0
+    for float_measurement, int_measurement, descriptions in raw_results:
+        t_float = float_measurement.median * 1e6
+        t_int = int_measurement.median * 1e6
+        rel_diff = abs(t_float - t_int) / (t_float + t_int) * 2  # difference relative to the mean of the two
+        parsed_results.append((t_float, t_int, rel_diff, descriptions))
+        for name, shape, order, steps in descriptions:
+            name_len = max(name_len, len(name))
+            shape_len = max(shape_len, len(shape))
+            order_len = max(order_len, len(order))
+            steps_len = max(steps_len, len(steps))
+
+    parsed_results.sort(key=operator.itemgetter(2))  # ascending by rel_diff
+
+    print(f"stmt: {stmt}")
+    print(f" diff    faster{'':>17}{' ' * name_len} ", end="")
+    print(f"{'shape'.ljust(shape_len)}{'':>16}{'order'.ljust(order_len)}", end="")
+    print(f"          steps\n{'-' * 100}")
+    for results, spacer in [(parsed_results[:10], "..."), (parsed_results[-10:], "")]:  # 10 smallest, then 10 largest diffs
+        for t_float, t_int, rel_diff, descriptions in results:
+            time_str = [f"{rel_diff * 100:>4.1f}%    {'int' if t_int < t_float else 'float':<20}"]
+            time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])  # blank padding for rows after the first
+            for t_str, (name, shape, order, steps) in zip(time_str, descriptions, strict=True):
+                name = f"{name}:".ljust(name_len + 1)
+                shape = shape.ljust(shape_len + 10)
+                order = order.ljust(order_len)
+                print(f"{t_str} {name}  {shape}|     {order}      |   {steps}")
+        print(spacer)
+
+
+def main() -> None:  # runs the float-vs-int comparison for three statements, 100 samples each
+    run(n=100, stmt="torch.median(x, dim=0)", fuzzer_cls=UnaryOpFuzzer)
+    run(n=100, stmt="torch.square(x)", fuzzer_cls=UnaryOpFuzzer)
+    run(n=100, stmt="x + y", fuzzer_cls=BinaryOpFuzzer)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8137d4d8791975b46b1314c2f3a05ed048dbdcd3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py
@@ -0,0 +1,25 @@
+"""Trivial use of Timer API:
+
+$ python -m examples.simple_timeit
+"""
+
+import torch
+
+import torch.utils.benchmark as benchmark_utils
+
+
+def main() -> None:
+    """Time a broadcasting add three times with both `timeit` and `blocked_autorange`."""
+    timer = benchmark_utils.Timer(
+        stmt="x + y",
+        globals={"x": torch.ones((4, 8)), "y": torch.ones((1, 8))},
+        label="Broadcasting add (4x8)",
+    )
+    for i in range(3):
+        print(f"Run: {i}\n{'-' * 40}")
+        print(f"timeit:\n{timer.timeit(10000)}\n")
+        print(f"autorange:\n{timer.blocked_autorange()}\n\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a33c34bc8229a44838ea93c29af34895061c53
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
@@ -0,0 +1,114 @@
+# mypy: allow-untyped-defs
+"""Microbenchmarks for the torch.fft module"""
+from argparse import ArgumentParser
+from collections import namedtuple
+from collections.abc import Iterable
+
+import torch
+import torch.fft
+from torch.utils import benchmark
+from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer
+
+
+def _dim_options(ndim):  # candidate `dim` arguments for an ndim-D input; callers pass each one as func(x, dim=dim)
+    if ndim == 1:
+        return [None]
+    elif ndim == 2:
+        return [0, 1, None]
+    elif ndim == 3:
+        return [0, 1, 2, (0, 1), (0, 2), None]  # single axes, two axis pairs, and None
+    raise ValueError(f"Expected ndim in range 1-3, got {ndim}")
+
+
+def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int,
+                  probability_regular: float):
+    cuda = device == 'cuda'  # the fuzzer takes a bool flag rather than a device string
+    spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda,
+                                       probability_regular=probability_regular)
+    results = []
+    for tensors, tensor_params, params in spectral_fuzzer.take(samples):
+        shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]  # keep only the first ndim sizes
+        str_shape = ' x '.join([f"{s:<4}" for s in shape])
+        sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
+        for dim in _dim_options(params['ndim']):
+            for nthreads in (1, 4, 16) if not cuda else (1,):  # thread count is only varied when not on cuda
+                measurement = benchmark.Timer(
+                    stmt='func(x, dim=dim)',
+                    globals={'func': function, 'x': tensors['x'], 'dim': dim},
+                    label=f"{name}_{device}",
+                    sub_label=sub_label,
+                    description=f"dim={dim}",
+                    num_threads=nthreads,
+                ).blocked_autorange(min_run_time=1)
+                measurement.metadata = {  # read back by _output_csv
+                    'name': name,
+                    'device': device,
+                    'dim': dim,
+                    'shape': shape,
+                }
+                measurement.metadata.update(tensor_params['x'])  # merge in the fuzzer's properties for x
+                results.append(measurement)
+    return results  # flat list: one measurement per (sample, dim, nthreads)
+
+
+Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype'])  # dtype is the fuzzed input dtype
+BENCHMARKS = [
+    Benchmark('fft_real', torch.fft.fftn, torch.float32),
+    Benchmark('fft_complex', torch.fft.fftn, torch.complex64),
+    Benchmark('ifft', torch.fft.ifftn, torch.complex64),
+    Benchmark('rfft', torch.fft.rfftn, torch.float32),
+    Benchmark('irfft', torch.fft.irfftn, torch.complex64),
+]
+BENCHMARK_MAP = {b.name: b for b in BENCHMARKS}  # lookup used to resolve --bench names
+BENCHMARK_NAMES = [b.name for b in BENCHMARKS]  # choices and default for --bench
+DEVICE_NAMES = ['cpu', 'cuda']  # choices and default for --device
+
+def _output_csv(file, results) -> None:
+    """Write a CSV header to `file`, then one row per measurement in `results`."""
+    file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n')
+    for measurement in results:
+        metadata = measurement.metadata
+        device, dim, shape, name, numel, contiguous = (
+            metadata['device'], metadata['dim'], metadata['shape'],
+            metadata['name'], metadata['numel'], metadata['is_contiguous'])
+        if isinstance(dim, Iterable):
+            dim_str = '-'.join(str(d) for d in dim)
+        else:
+            dim_str = str(dim)
+        # Computed for every row: tuple dims previously skipped it, leaving it unset or stale.
+        shape_str = 'x'.join(str(s) for s in shape)
+        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,
+              measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
+              sep=',', file=file)
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser(description=__doc__)
+    parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES)
+    parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES)
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--samples', type=int, default=10)
+    parser.add_argument('--probability-regular', '--probability_regular', type=float, default=1.0)
+    parser.add_argument('-o', '--output', type=str)
+    args = parser.parse_args()
+    # One run_benchmark call per (device, bench) pair; `i` counts finished pairs for the progress line.
+    num_benchmarks = len(args.device) * len(args.bench)
+    i = 0
+    results = []
+    for device in args.device:
+        for bench in (BENCHMARK_MAP[b] for b in args.bench):
+            results += run_benchmark(
+                name=bench.name, function=bench.function, dtype=bench.dtype,
+                seed=args.seed, device=device, samples=args.samples,
+                probability_regular=args.probability_regular)
+            i += 1
+            print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})')
+    # The CSV file is only written when -o/--output is given; the table below is always printed.
+    if args.output is not None:
+        with open(args.output, 'w') as f:
+            _output_csv(f, results)
+
+    compare = benchmark.Compare(results)
+    compare.trim_significant_figures()
+    compare.colorize()
+    compare.print()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2965fce687ded38a4f237bc4abb19d095ea82d94
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dee952b7afe4d8a9338db038dfbbe3622d6d7cf4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84093389e89e60bebafa62f2afb9531949f18a86
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e389ce9b14eb19eb5f2b74cc49819ec03502e87
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d53b7fd73a4aa63fd09c06ddcde81e202fe26e91
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9aa28366404243073fafdc2def7e44b3f4a996db
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e53c310111bec8166e6090f351e39153dbe400aa
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py
@@ -0,0 +1,107 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor
+
+
+_MIN_DIM_SIZE = 16  # lower bound for a fuzzed dimension size
+_MAX_DIM_SIZE = 16 * 1024 ** 2  # upper bound for a fuzzed dimension size
+_POW_TWO_SIZES = tuple(2 ** i for i in range(  # every power of two from the min to the max, inclusive
+    int(np.log2(_MIN_DIM_SIZE)),
+    int(np.log2(_MAX_DIM_SIZE)) + 1,
+))
+
+
+class BinaryOpFuzzer(Fuzzer):
+    """Fuzzer yielding broadcast-compatible `x` and `y` tensors, each with its own step parameters.
+
+    `seed` is forwarded to Fuzzer; `dtype` and `cuda` are forwarded to both FuzzedTensors."""
+    def __init__(self, seed, dtype=torch.float32, cuda=False) -> None:
+        super().__init__(
+            parameters=[
+                # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.)
+                FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True),
+                # Shapes for `x` and `y`.
+                #       It is important to test all shapes, however
+                #   powers of two are especially important and therefore
+                #   warrant special attention. This is done by generating
+                #   both a value drawn from all integers between the min and
+                #   max allowed values, and another from only the powers of two
+                #   (both distributions are loguniform) and then randomly
+                #   selecting between the two.
+                #       Moreover, `y` will occasionally have singleton
+                #   dimensions in order to test broadcasting.
+                [
+                    FuzzedParameter(
+                        name=f"k_any_{i}",
+                        minval=_MIN_DIM_SIZE,
+                        maxval=_MAX_DIM_SIZE,
+                        distribution="loguniform",
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k_pow2_{i}",
+                        distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES}
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k{i}",
+                        distribution={
+                            ParameterAlias(f"k_any_{i}"): 0.8,
+                            ParameterAlias(f"k_pow2_{i}"): 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+                # Sizes for `y`: each dim either matches `x` (80%) or is a singleton (20%).
+                [
+                    FuzzedParameter(
+                        name=f"y_k{i}",
+                        distribution={
+                            ParameterAlias(f"k{i}"): 0.8,
+                            1: 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+                # Steps for `x` and `y`. (Benchmarks strided memory access.)
+                [
+                    FuzzedParameter(
+                        name=f"{name}_step_{i}",
+                        distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04},
+                    )
+                    for i in range(3)
+                    for name in ("x", "y")
+                ],
+                # Repeatable entropy for downstream applications.
+                FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"),
+            ],
+            tensors=[
+                FuzzedTensor(
+                    name="x",
+                    size=("k0", "k1", "k2"),
+                    steps=("x_step_0", "x_step_1", "x_step_2"),
+                    probability_contiguous=0.75,
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    max_allocation_bytes=2 * 1024**3,  # 2 GB
+                    dim_parameter="dim",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+                FuzzedTensor(
+                    name="y",
+                    size=("y_k0", "y_k1", "y_k2"),
+                    steps=("y_step_0", "y_step_1", "y_step_2"),  # was x_step_*, leaving y_step_* unused
+                    probability_contiguous=0.75,
+                    max_allocation_bytes=2 * 1024**3,  # 2 GB
+                    dim_parameter="dim",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+            ],
+            seed=seed,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6269464e0d53d2c3c51ed5406d7c88598fec79
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py
@@ -0,0 +1,107 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor
+
+
+# Smallest and largest values used for the per-dimension size parameters.
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+# Every power of two from _MIN_DIM_SIZE through _MAX_DIM_SIZE (both included).
+_POW_TWO_SIZES = tuple(
+    2 ** exponent
+    for exponent in range(int(np.log2(_MIN_DIM_SIZE)), int(np.log2(_MAX_DIM_SIZE)) + 1)
+)
+
+
+class BinaryOpSparseFuzzer(Fuzzer):
+    """Fuzzer that generates two sparse operands, ``x`` and ``y``.
+
+    Both tensors are declared with the same ``dim_parameter``, ``sparse_dim``,
+    ``density`` and ``coalesced`` parameters. Each ``y_k{i}`` aliases
+    ``k{i}`` with probability 1.0, so the size parameters of ``y`` always
+    equal those of ``x``.
+    """
+
+    def __init__(self, seed, dtype=torch.float32, cuda=False) -> None:
+        axes = range(3)
+
+        # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.)
+        dimensionality = FuzzedParameter(
+            "dim_parameter",
+            distribution={1: 0.3, 2: 0.4, 3: 0.3},
+            strict=True,
+        )
+        sparse_dimensionality = FuzzedParameter(
+            name="sparse_dim",
+            distribution={1: 0.4, 2: 0.4, 3: 0.2},
+            strict=True,
+        )
+
+        # Shapes for `x` and `y`.
+        #   All sizes matter, but powers of two deserve extra attention, so
+        #   two candidates are generated per axis: one drawn log-uniformly
+        #   from the whole [min, max] range and one drawn from the powers of
+        #   two only. `k{i}` then selects between the two candidates.
+        any_size_candidates = [
+            FuzzedParameter(
+                name=f"k_any_{axis}",
+                minval=_MIN_DIM_SIZE,
+                maxval=_MAX_DIM_SIZE,
+                distribution="loguniform",
+            )
+            for axis in axes
+        ]
+        pow2_size_candidates = [
+            FuzzedParameter(
+                name=f"k_pow2_{axis}",
+                distribution={size: 1.0 / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES},
+            )
+            for axis in axes
+        ]
+        x_sizes = [
+            FuzzedParameter(
+                name=f"k{axis}",
+                distribution={
+                    ParameterAlias(f"k_any_{axis}"): 0.8,
+                    ParameterAlias(f"k_pow2_{axis}"): 0.2,
+                },
+                strict=True,
+            )
+            for axis in axes
+        ]
+        # `y` reuses the size chosen for the matching axis of `x`.
+        y_sizes = [
+            FuzzedParameter(
+                name=f"y_k{axis}",
+                distribution={ParameterAlias(f"k{axis}"): 1.0},
+                strict=True,
+            )
+            for axis in axes
+        ]
+
+        density = FuzzedParameter(
+            name="density",
+            distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3},
+        )
+        coalesced = FuzzedParameter(
+            name="coalesced",
+            distribution={True: 0.5, False: 0.5},
+        )
+        # Repeatable entropy for downstream applications.
+        random_value = FuzzedParameter(
+            name="random_value",
+            minval=0,
+            maxval=2 ** 32 - 1,
+            distribution="uniform",
+        )
+
+        x_tensor = FuzzedSparseTensor(
+            name="x",
+            size=("k0", "k1", "k2"),
+            dim_parameter="dim_parameter",
+            sparse_dim="sparse_dim",
+            density="density",
+            coalesced="coalesced",
+            min_elements=4 * 1024,
+            max_elements=32 * 1024 ** 2,
+            dtype=dtype,
+            cuda=cuda,
+        )
+        y_tensor = FuzzedSparseTensor(
+            name="y",
+            size=("y_k0", "y_k1", "y_k2"),
+            dim_parameter="dim_parameter",
+            sparse_dim="sparse_dim",
+            density="density",
+            coalesced="coalesced",
+            min_elements=4 * 1024,
+            max_elements=32 * 1024 ** 2,
+            dtype=dtype,
+            cuda=cuda,
+        )
+
+        # The parameter order (and the nesting of the per-axis lists) is kept
+        # exactly as declared above.
+        super().__init__(
+            parameters=[
+                dimensionality,
+                sparse_dimensionality,
+                any_size_candidates,
+                pow2_size_candidates,
+                x_sizes,
+                y_sizes,
+                density,
+                coalesced,
+                random_value,
+            ],
+            tensors=[x_tensor, y_tensor],
+            seed=seed,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..18921becd078cb3140a1705078dd57f4a597a2ec
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import torch
+
+if TYPE_CHECKING:
+    from torch.types import _dtype
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor
+
+__all__ = ["UnaryOpSparseFuzzer"]
+
+# Smallest and largest values used for the per-dimension size parameters.
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+# Every power of two from _MIN_DIM_SIZE through _MAX_DIM_SIZE (both included).
+_POW_TWO_SIZES = tuple(
+    2 ** exponent
+    for exponent in range(int(np.log2(_MIN_DIM_SIZE)), int(np.log2(_MAX_DIM_SIZE)) + 1)
+)
+
+class UnaryOpSparseFuzzer(Fuzzer):
+    """Fuzzer that generates a single sparse operand named ``x``.
+
+    When ``dtype`` is omitted it falls back to ``torch.float32``.
+    """
+
+    def __init__(self, seed: int | None, dtype: _dtype | None = None, cuda: bool = False) -> None:
+        if dtype is None:
+            dtype = getattr(torch, 'float32', None)
+        axes = range(3)
+
+        # Sparse dim parameter of x. (e.g. 1D, 2D, or 3D.)
+        dimensionality = FuzzedParameter(
+            "dim_parameter",
+            distribution={1: 0.3, 2: 0.4, 3: 0.3},
+            strict=True,
+        )
+        sparse_dimensionality = FuzzedParameter(
+            name="sparse_dim",
+            distribution={1: 0.4, 2: 0.4, 3: 0.2},
+            strict=True,
+        )
+
+        # Shapes for `x`.
+        #   All sizes matter, but powers of two deserve extra attention, so
+        #   two candidates are generated per axis: one drawn log-uniformly
+        #   from the whole [min, max] range and one drawn from the powers of
+        #   two only. `k{i}` then selects between the two candidates.
+        any_size_candidates = [
+            FuzzedParameter(
+                name=f"k_any_{axis}",
+                minval=_MIN_DIM_SIZE,
+                maxval=_MAX_DIM_SIZE,
+                distribution="loguniform",
+            )
+            for axis in axes
+        ]
+        pow2_size_candidates = [
+            FuzzedParameter(
+                name=f"k_pow2_{axis}",
+                distribution={size: 1.0 / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES},
+            )
+            for axis in axes
+        ]
+        sizes = [
+            FuzzedParameter(
+                name=f"k{axis}",
+                distribution={
+                    ParameterAlias(f"k_any_{axis}"): 0.8,
+                    ParameterAlias(f"k_pow2_{axis}"): 0.2,
+                },
+                strict=True,
+            )
+            for axis in axes
+        ]
+
+        density = FuzzedParameter(
+            name="density",
+            distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3},
+        )
+        coalesced = FuzzedParameter(
+            name="coalesced",
+            distribution={True: 0.5, False: 0.5},
+        )
+        random_value = FuzzedParameter(
+            name="random_value",
+            minval=0,
+            maxval=2 ** 32 - 1,
+            distribution="uniform",
+        )
+
+        x_tensor = FuzzedSparseTensor(
+            name="x",
+            size=("k0", "k1", "k2"),
+            dim_parameter="dim_parameter",
+            sparse_dim="sparse_dim",
+            min_elements=4 * 1024,
+            max_elements=32 * 1024 ** 2,
+            density="density",
+            coalesced="coalesced",
+            dtype=dtype,
+            cuda=cuda,
+        )
+
+        # The parameter order (and the nesting of the per-axis lists) is kept
+        # exactly as declared above.
+        super().__init__(
+            parameters=[
+                dimensionality,
+                sparse_dimensionality,
+                any_size_candidates,
+                pow2_size_candidates,
+                sizes,
+                density,
+                coalesced,
+                random_value,
+            ],
+            tensors=[x_tensor],
+            seed=seed,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py
new file mode 100644
index 0000000000000000000000000000000000000000..c324e338dca5da3d2b8b9a55e7d89f108d6783dd
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py
@@ -0,0 +1,94 @@
+# mypy: allow-untyped-defs
+import math
+
+import torch
+from torch.utils import benchmark
+from torch.utils.benchmark import FuzzedParameter, FuzzedTensor, ParameterAlias
+
+
+__all__ = ['SpectralOpFuzzer']
+
+# Smallest and largest values used for the per-dimension size parameters
+# (passed as `minval` / `maxval` of the `k_any_{i}` parameters below).
+MIN_DIM_SIZE = 16
+MAX_DIM_SIZE = 16 * 1024
+
+def power_range(upper_bound, base):
+    """Return a generator over the powers of ``base`` up to ``upper_bound``.
+
+    The generator yields ``base ** 0``, ``base ** 1``, ... in increasing
+    order and stops before the first power greater than ``upper_bound``.
+    Nothing is yielded when ``upper_bound`` is below 1.
+
+    The powers are built by repeated multiplication instead of deriving the
+    largest exponent from ``int(math.log(upper_bound, base))``: the floating
+    point logarithm can land just below an exact integer (for example
+    ``math.log(243, 3)`` evaluates to ``4.999999999999999``), and truncating
+    it would silently drop the largest power.
+
+    Raises:
+        ValueError: if ``base`` is not greater than 1, because the powers
+            would then never grow past ``upper_bound``.
+    """
+    if base <= 1:
+        raise ValueError(f"base must be greater than 1, got {base}")
+
+    def _powers():
+        value = 1
+        while value <= upper_bound:
+            yield value
+            value *= base
+
+    return _powers()
+
+# Sorted list of the regular numbers that are greater than MIN_DIM_SIZE and
+# at most MAX_DIM_SIZE.
+# These numbers have no prime factors other than 2, 3, and 5
+# and are usually the fastest in FFT implementations.
+# NOTE: the strict `>` comparison below leaves MIN_DIM_SIZE itself out of the
+# list, even though it is a power of two.
+REGULAR_SIZES = []
+for i in power_range(MAX_DIM_SIZE, 2):
+    # Each inner bound is the largest factor that keeps the running product
+    # at or below MAX_DIM_SIZE.
+    for j in power_range(MAX_DIM_SIZE // i, 3):
+        ij = i * j
+        for k in power_range(MAX_DIM_SIZE // ij, 5):
+            ijk = ij * k
+            if ijk > MIN_DIM_SIZE:
+                REGULAR_SIZES.append(ijk)
+REGULAR_SIZES.sort()
+
+class SpectralOpFuzzer(benchmark.Fuzzer):
+    """Fuzzer that generates a single, possibly strided, tensor ``x``.
+
+    ``probability_regular`` is the probability that a dimension size is taken
+    from ``REGULAR_SIZES`` rather than from the full size range.
+    """
+
+    def __init__(self, *, seed: int, dtype=torch.float64,
+                 cuda: bool = False, probability_regular: float = 1.0) -> None:
+        axes = range(3)
+
+        # Dimensionality of x. (e.g. 1D, 2D, or 3D.)
+        ndim = FuzzedParameter("ndim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True)
+
+        # Shapes for `x`.
+        #   All sizes matter, but regular sizes are especially important to
+        #   the FFT, so two candidates are generated per axis: one drawn
+        #   log-uniformly from the whole [min, max] range and one drawn from
+        #   the regular numbers only. `k{i}` then selects between the two.
+        any_size_candidates = [
+            FuzzedParameter(
+                name=f"k_any_{axis}",
+                minval=MIN_DIM_SIZE,
+                maxval=MAX_DIM_SIZE,
+                distribution="loguniform",
+            )
+            for axis in axes
+        ]
+        regular_size_candidates = [
+            FuzzedParameter(
+                name=f"k_regular_{axis}",
+                distribution={size: 1.0 / len(REGULAR_SIZES) for size in REGULAR_SIZES},
+            )
+            for axis in axes
+        ]
+        sizes = [
+            FuzzedParameter(
+                name=f"k{axis}",
+                distribution={
+                    ParameterAlias(f"k_regular_{axis}"): probability_regular,
+                    ParameterAlias(f"k_any_{axis}"): 1 - probability_regular,
+                },
+                strict=True,
+            )
+            for axis in axes
+        ]
+
+        # Steps for `x`. (Benchmarks strided memory access.)
+        steps = [
+            FuzzedParameter(
+                name=f"step_{axis}",
+                distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04},
+            )
+            for axis in axes
+        ]
+
+        x_tensor = FuzzedTensor(
+            name="x",
+            size=("k0", "k1", "k2"),
+            steps=("step_0", "step_1", "step_2"),
+            probability_contiguous=0.75,
+            min_elements=4 * 1024,
+            max_elements=32 * 1024 ** 2,
+            max_allocation_bytes=2 * 1024**3,  # 2 GB
+            dim_parameter="ndim",
+            dtype=dtype,
+            cuda=cuda,
+        )
+
+        # The parameter order (and the nesting of the per-axis lists) is kept
+        # exactly as declared above.
+        super().__init__(
+            parameters=[
+                ndim,
+                any_size_candidates,
+                regular_size_candidates,
+                sizes,
+                steps,
+            ],
+            tensors=[x_tensor],
+            seed=seed,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..6008adfe459218cd0e239efede5a3f1cd35ee61b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py
@@ -0,0 +1,82 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor
+
+
+# Smallest and largest values used for the per-dimension size parameters.
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+# Every power of two from _MIN_DIM_SIZE through _MAX_DIM_SIZE (both included).
+_POW_TWO_SIZES = tuple(
+    2 ** exponent
+    for exponent in range(int(np.log2(_MIN_DIM_SIZE)), int(np.log2(_MAX_DIM_SIZE)) + 1)
+)
+
+
+class UnaryOpFuzzer(Fuzzer):
+    """Fuzzer that generates a single, possibly strided, dense tensor ``x``."""
+
+    def __init__(self, seed, dtype=torch.float32, cuda=False) -> None:
+        axes = range(3)
+
+        # Dimensionality of x. (e.g. 1D, 2D, or 3D.)
+        dim = FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True)
+
+        # Shapes for `x`.
+        #   All sizes matter, but powers of two deserve extra attention, so
+        #   two candidates are generated per axis: one drawn log-uniformly
+        #   from the whole [min, max] range and one drawn from the powers of
+        #   two only. `k{i}` then selects between the two candidates.
+        any_size_candidates = [
+            FuzzedParameter(
+                name=f"k_any_{axis}",
+                minval=_MIN_DIM_SIZE,
+                maxval=_MAX_DIM_SIZE,
+                distribution="loguniform",
+            )
+            for axis in axes
+        ]
+        pow2_size_candidates = [
+            FuzzedParameter(
+                name=f"k_pow2_{axis}",
+                distribution={size: 1.0 / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES},
+            )
+            for axis in axes
+        ]
+        sizes = [
+            FuzzedParameter(
+                name=f"k{axis}",
+                distribution={
+                    ParameterAlias(f"k_any_{axis}"): 0.8,
+                    ParameterAlias(f"k_pow2_{axis}"): 0.2,
+                },
+                strict=True,
+            )
+            for axis in axes
+        ]
+
+        # Steps for `x`. (Benchmarks strided memory access.)
+        steps = [
+            FuzzedParameter(
+                name=f"x_step_{axis}",
+                distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04},
+            )
+            for axis in axes
+        ]
+
+        # Repeatable entropy for downstream applications.
+        random_value = FuzzedParameter(
+            name="random_value",
+            minval=0,
+            maxval=2 ** 32 - 1,
+            distribution="uniform",
+        )
+
+        x_tensor = FuzzedTensor(
+            name="x",
+            size=("k0", "k1", "k2"),
+            steps=("x_step_0", "x_step_1", "x_step_2"),
+            probability_contiguous=0.75,
+            min_elements=4 * 1024,
+            max_elements=32 * 1024 ** 2,
+            max_allocation_bytes=2 * 1024**3,  # 2 GB
+            dim_parameter="dim",
+            dtype=dtype,
+            cuda=cuda,
+        )
+
+        # The parameter order (and the nesting of the per-axis lists) is kept
+        # exactly as declared above.
+        super().__init__(
+            parameters=[
+                dim,
+                any_size_candidates,
+                pow2_size_candidates,
+                sizes,
+                steps,
+                random_value,
+            ],
+            tensors=[x_tensor],
+            seed=seed,
+        )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b8ad057eb68ed742feb70d16353a3b6bdb32e09
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6bf60f651c96ed5d5b5e439a988f71aad1eca64
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c45772d84fdd0e56d0bf15c3df8d9a30140e181
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32e4ade1a3bb57e84a29fc5e37eac0766d30bcae
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..231781053a8a40d0f2514a5ddbc934a1a02bedb0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d21f870147085d637046548e16533d55ba1bff3
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..336a0690013771bd11bda9b6620e31c434a55fe6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1216dc1bdde634799e5fa3a68a5cb7c482f3860a
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..852a52e13b8200a6dd43e71ee0cae1158ad56467
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91e3d12b29e1c050edbadaebb877d7fc0761e57
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py
@@ -0,0 +1,42 @@
+from typing import Any
+from collections.abc import Callable
+from typing_extensions import Protocol, runtime_checkable
+
+
+class TimerClass(Protocol):
+    """This is the portion of the `timeit.Timer` API used by benchmark utils.
+
+    Being a `Protocol`, it only describes the expected call signatures; the
+    method bodies are intentionally empty.
+    """
+    def __init__(
+        self,
+        stmt: str,
+        setup: str,
+        timer: Callable[[], float],
+        globals: dict[str, Any],
+        **kwargs: Any,
+    ) -> None:
+        """Accept a statement, its setup code, a clock callable and a globals dict.
+
+        Extra keyword arguments are permitted so implementations can take
+        additional options.
+        """
+        ...
+
+    def timeit(self, number: int) -> float:
+        """Counterpart of `timeit.Timer.timeit`: takes a count, returns a float."""
+        ...
+
+
+# `runtime_checkable` allows `isinstance` checks against this protocol.
+@runtime_checkable
+class TimeitModuleType(Protocol):
+    """Modules generated from `timeit_template.cpp`."""
+    def timeit(self, number: int) -> float:
+        """Same signature as `TimerClass.timeit`: takes a count, returns a float."""
+        ...
+
+
+class CallgrindModuleType(Protocol):
+    """Replicates the valgrind endpoints in `torch._C`.
+
+    These bindings are used to collect Callgrind profiles on earlier versions
+    of PyTorch and will eventually be removed.
+    """
+    # Standard module attributes expected on any conforming module object.
+    __file__: str
+    __name__: str
+
+    def _valgrind_supported_platform(self) -> bool:
+        """Counterpart of `torch._C._valgrind_supported_platform`."""
+        ...
+
+    def _valgrind_toggle(self) -> None:
+        """Counterpart of `torch._C._valgrind_toggle`."""
+        ...
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4f328d19083f0fc92da79e34d70a68b8ef891ff
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py
@@ -0,0 +1,359 @@
+"""Base shared classes and utilities."""
+
+import collections
+import contextlib
+import dataclasses
+import os
+import shutil
+import tempfile
+import textwrap
+import time
+from typing import cast, Any
+from collections.abc import Iterable, Iterator
+import uuid
+
+import torch
+
+
+__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"]
+
+
+# Bounds used by `Measurement.significant_figures`: the reported count is
+# capped at _MAX_SIGNIFICANT_FIGURES, and the confidence interval it is
+# derived from is floored at _MIN_CONFIDENCE_INTERVAL.
+_MAX_SIGNIFICANT_FIGURES = 4
+_MIN_CONFIDENCE_INTERVAL = 25e-9  # 25 ns
+
+# Measurement will include a warning if the distribution is suspect. All
+# runs are expected to have some variation; these parameters set the
+# thresholds. Both are compared against the ratio IQR / median
+# (see `Measurement.meets_confidence`).
+_IQR_WARN_THRESHOLD = 0.1
+_IQR_GROSS_WARN_THRESHOLD = 0.25
+
+
+@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True)
+class TaskSpec:
+    """Immutable description of what a Timer runs (everything except globals)."""
+    stmt: str
+    setup: str
+    global_setup: str = ""
+    label: str | None = None
+    sub_label: str | None = None
+    description: str | None = None
+    env: str | None = None
+    num_threads: int = 1
+
+    @property
+    def title(self) -> str:
+        """Best effort attempt at a string label for the measurement.
+
+        Uses `label` when it is set, otherwise a single-line `stmt`; a
+        multi-line `stmt` is rendered as an indented block instead.
+        """
+        suffix = f": {self.sub_label}" if self.sub_label else ""
+        if self.label is not None:
+            return self.label + suffix
+        if "\n" not in self.stmt:
+            return self.stmt + suffix
+
+        annotation = f" ({self.sub_label})" if self.sub_label else ""
+        return f"stmt:{annotation}\n" + textwrap.indent(self.stmt, "  ")
+
+    def setup_str(self) -> str:
+        """Render the setup code, or "" when it is empty or just `pass`."""
+        if not self.setup or self.setup == "pass":
+            return ""
+        if "\n" in self.setup:
+            return "setup:\n" + textwrap.indent(self.setup, "  ")
+        return f"setup: {self.setup}"
+
+    def summarize(self) -> str:
+        """Build TaskSpec portion of repr string for other containers."""
+        parts = (self.title, self.description or "", self.setup_str())
+        # Empty parts are dropped; multi-line parts get a trailing newline so
+        # that they are followed by a blank line.
+        return "\n".join(f"{part}\n" if "\n" in part else part for part in parts if part)
+
+_TASKSPEC_FIELDS = tuple(field.name for field in dataclasses.fields(TaskSpec))
+
+
+@dataclasses.dataclass(init=True, repr=False)
+class Measurement:
+    """The result of a Timer measurement.
+
+    This class stores one or more measurements of a given statement. It is
+    serializable and provides several convenience methods
+    (including a detailed __repr__) for downstream consumers.
+    """
+    number_per_run: int
+    raw_times: list[float]
+    task_spec: TaskSpec
+    metadata: dict[Any, Any] | None = None  # Reserved for user payloads.
+
+    def __post_init__(self) -> None:
+        self._sorted_times: tuple[float, ...] = ()
+        self._warnings: tuple[str, ...] = ()
+        self._median: float = -1.0
+        self._mean: float = -1.0
+        self._p25: float = -1.0
+        self._p75: float = -1.0
+
+    def __getattr__(self, name: str) -> Any:
+        # Forward TaskSpec fields for convenience.
+        if name in _TASKSPEC_FIELDS:
+            return getattr(self.task_spec, name)
+        return super().__getattribute__(name)
+
+    # =========================================================================
+    # == Convenience methods for statistics ===================================
+    # =========================================================================
+    #
+    # These methods use raw time divided by number_per_run; this is an
+    # extrapolation and hides the fact that different number_per_run will
+    # result in different amortization of overheads, however if Timer has
+    # selected an appropriate number_per_run then this is a non-issue, and
+    # forcing users to handle that division would result in a poor experience.
+    @property
+    def times(self) -> list[float]:
+        return [t / self.number_per_run for t in self.raw_times]
+
+    @property
+    def median(self) -> float:
+        self._lazy_init()
+        return self._median
+
+    @property
+    def mean(self) -> float:
+        self._lazy_init()
+        return self._mean
+
+    @property
+    def iqr(self) -> float:
+        self._lazy_init()
+        return self._p75 - self._p25
+
+    @property
+    def significant_figures(self) -> int:
+        """Approximate significant figure estimate.
+
+        This property is intended to give a convenient way to estimate the
+        precision of a measurement. It only uses the interquartile region to
+        estimate statistics to try to mitigate skew from the tails, and
+        uses a static z value of 1.645 since it is not expected to be used
+        for small values of `n`, so z can approximate `t`.
+
+        The significant figure estimation used in conjunction with the
+        `trim_sigfig` method to provide a more human interpretable data
+        summary. __repr__ does not use this method; it simply displays raw
+        values. Significant figure estimation is intended for `Compare`.
+        """
+        self._lazy_init()
+        n_total = len(self._sorted_times)
+        lower_bound = int(n_total // 4)
+        upper_bound = int(torch.tensor(3 * n_total / 4).ceil())
+        interquartile_points: tuple[float, ...] = self._sorted_times[lower_bound:upper_bound]
+        std = torch.tensor(interquartile_points).std(unbiased=False).item()
+        sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item()
+
+        # Rough estimates. These are by no means statistically rigorous.
+        confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL)
+        relative_ci = torch.tensor(self._median / confidence_interval).log10().item()
+        num_significant_figures = int(torch.tensor(relative_ci).floor())
+        return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES)
+
+    @property
+    def has_warnings(self) -> bool:
+        self._lazy_init()
+        return bool(self._warnings)
+
+    def _lazy_init(self) -> None:
+        if self.raw_times and not self._sorted_times:
+            self._sorted_times = tuple(sorted(self.times))
+            _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64)
+            self._median = _sorted_times.quantile(.5).item()
+            self._mean = _sorted_times.mean().item()
+            self._p25 = _sorted_times.quantile(.25).item()
+            self._p75 = _sorted_times.quantile(.75).item()
+
+            def add_warning(msg: str) -> None:
+                rel_iqr = self.iqr / self.median * 100
+                self._warnings += (
+                    f"  WARNING: Interquartile range is {rel_iqr:.1f}% "
+                    f"of the median measurement.\n           {msg}",
+                )
+
+            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
+                add_warning("This suggests significant environmental influence.")
+            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
+                add_warning("This could indicate system fluctuation.")
+
+
+    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
+        return self.iqr / self.median < threshold
+
+    @property
+    def title(self) -> str:
+        return self.task_spec.title
+
+    @property
+    def env(self) -> str:
+        return (
+            "Unspecified env" if self.taskspec.env is None
+            else cast(str, self.taskspec.env)
+        )
+
+    @property
+    def as_row_name(self) -> str:
+        return self.sub_label or self.stmt or "[Unknown]"
+
+    def __repr__(self) -> str:
+        """
+        Example repr:
+            <torch.utils.benchmark.utils.common.Measurement object at 0x7f395b6ac110>
+              Broadcasting add (4x8)
+              Median: 5.73 us
+              IQR:    2.25 us (4.01 to 6.26)
+              372 measurements, 100 runs per measurement, 1 thread
+              WARNING: Interquartile range is 39.4% of the median measurement.
+                       This suggests significant environmental influence.
+
+        The first line is whatever `super().__repr__()` returns. The IQR line
+        is omitted when there are fewer than four measurements.
+        """
+        self._lazy_init()
+        # `skip_line` is a sentinel: any rendered line containing it is dropped
+        # by the filter in the return statement below.
+        skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n"
+        n = len(self._sorted_times)
+        time_unit, time_scale = select_unit(self._median)
+        # With fewer than four samples the IQR line is tagged for removal.
+        iqr_filter = '' if n >= 4 else skip_line
+
+        repr_str = f"""
+{super().__repr__()}
+{self.task_spec.summarize()}
+  {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit}
+  {iqr_filter}IQR:    {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f})
+  {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''}
+{newline.join(self._warnings)}""".strip()  # noqa: B950
+
+        # Remove the lines that were tagged with the sentinel above.
+        return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l)
+
+    @staticmethod
+    def merge(measurements: Iterable["Measurement"]) -> list["Measurement"]:
+        """Combine replicate measurements that share a task spec.
+
+        Times are pooled through each measurement's `.times`, so every merged
+        result is expressed with `number_per_run=1`. Metadata is not carried
+        over, since it might differ between replicates.
+        """
+        by_spec: collections.defaultdict[TaskSpec, list[Measurement]] = collections.defaultdict(list)
+        for measurement in measurements:
+            by_spec[measurement.task_spec].append(measurement)
+
+        merged: list[Measurement] = []
+        for spec, replicates in by_spec.items():
+            # Replicates could have different `number_per_run`, so pool the
+            # normalized `.times` rather than the raw values.
+            pooled_times: list[float] = [t for rep in replicates for t in rep.times]
+            merged.append(
+                Measurement(
+                    number_per_run=1,
+                    raw_times=pooled_times,
+                    task_spec=spec,
+                    metadata=None,
+                )
+            )
+        return merged
+
+
+def select_unit(t: float) -> tuple[str, float]:
+    """Pick the time unit that renders `t` with an O(1) magnitude.
+
+    Returns a `(unit, scale)` pair, where `scale` is the length of one unit
+    in seconds. Used to format numbers for human consumption.
+    """
+    # Bucket the base-10 exponent in steps of three: -3 -> ns, -2 -> us, -1 -> ms.
+    exponent_bucket = int(torch.tensor(t).log10().item() // 3)
+    unit_by_bucket = {-3: "ns", -2: "us", -1: "ms"}
+    scale_by_unit = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}
+
+    # Every other bucket falls back to seconds.
+    time_unit = unit_by_bucket.get(exponent_bucket, "s")
+    return time_unit, scale_by_unit[time_unit]
+
+
+def unit_to_english(u: str) -> str:
+    """Spell out a time-unit abbreviation ("ns", "us", "ms" or "s") in English.
+
+    Any other abbreviation raises `KeyError`.
+    """
+    english_names = dict(
+        ns="nanosecond",
+        us="microsecond",
+        ms="millisecond",
+        s="second",
+    )
+    return english_names[u]
+
+
+def trim_sigfig(x: float, n: int) -> float:
+    """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)
+
+    Args:
+        x: Value to round.
+        n: Number of significant figures to keep; must be integral.
+
+    Returns:
+        `x` rounded to `n` significant figures. Zero is returned as `0.0`,
+        because it has no order of magnitude to round against.
+
+    Raises:
+        AssertionError: If `n` is not an integer value.
+    """
+    if n != int(n):
+        raise AssertionError("Number of significant figures must be an integer")
+    if x == 0:
+        # log10(0) is -inf, and int(-inf) below would raise OverflowError.
+        return 0.0
+    magnitude = int(torch.tensor(x).abs().log10().ceil().item())
+    scale = 10 ** (magnitude - n)
+    return float(torch.tensor(x / scale).round() * scale)
+
+
+def ordered_unique(elements: Iterable[Any]) -> list[Any]:
+    """Return the distinct items of `elements` in first-seen order."""
+    # dict keys are unique and keep insertion order.
+    first_seen = dict.fromkeys(elements)
+    return list(first_seen)
+
+
+@contextlib.contextmanager
+def set_torch_threads(n: int) -> Iterator[None]:
+    """Context manager that runs its body with `torch.set_num_threads(n)`.
+
+    The thread count observed on entry is restored on exit, including when
+    the body raises.
+    """
+    threads_on_entry = torch.get_num_threads()
+    try:
+        torch.set_num_threads(n)
+        yield
+    finally:
+        torch.set_num_threads(threads_on_entry)
+
+
+def _make_temp_dir(prefix: str | None = None, gc_dev_shm: bool = False) -> str:
+    """Create a temporary directory. The caller is responsible for cleanup.
+
+    This function is conceptually similar to `tempfile.mkdtemp`, but with
+    the key additional feature that it will use shared memory if the
+    `BENCHMARK_USE_DEV_SHM` environment variable is set to "1" or "true"
+    (case-insensitive). This is an
+    implementation detail, but an important one for cases where many Callgrind
+    measurements are collected at once. (Such as when collecting
+    microbenchmarks.)
+
+    This is an internal utility, and is exported solely so that microbenchmarks
+    can reuse the util.
+
+    Args:
+        prefix: Leading part of the directory name; defaults to
+            `tempfile.gettempprefix()` when None or empty.
+        gc_dev_shm: When shared memory is in use, also remove directories
+            under the shared root whose recorded owner process no longer
+            exists.
+
+    Returns:
+        Path of the newly created directory.
+
+    Raises:
+        AssertionError: If shared memory is requested on a non-POSIX platform
+            or `/dev/shm` does not exist.
+    """
+    use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true")
+    if use_dev_shm:
+        root = "/dev/shm/pytorch_benchmark_utils"
+        if os.name != "posix":
+            raise AssertionError(f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}")
+        if not os.path.exists("/dev/shm"):
+            raise AssertionError("This system does not appear to support tmpfs (/dev/shm).")
+        os.makedirs(root, exist_ok=True)
+
+        # Because we're working in shared memory, it is more important than
+        # usual to clean up ALL intermediate files. However we don't want every
+        # worker to walk over all outstanding directories, so instead we only
+        # check when we are sure that it won't lead to contention.
+        if gc_dev_shm:
+            for i in os.listdir(root):
+                # Directories without an owner file are left alone.
+                owner_file = os.path.join(root, i, "owner.pid")
+                if not os.path.exists(owner_file):
+                    continue
+
+                with open(owner_file) as f:
+                    owner_pid = int(f.read())
+
+                # Never collect directories owned by the current process.
+                if owner_pid == os.getpid():
+                    continue
+
+                try:
+                    # Signal 0 delivers nothing; it only probes whether the pid exists.
+                    # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
+                    os.kill(owner_pid, 0)
+
+                except OSError:
+                    print(f"Detected that {os.path.join(root, i)} was orphaned in shared memory. Cleaning up.")
+                    shutil.rmtree(os.path.join(root, i))
+
+    else:
+        root = tempfile.gettempdir()
+
+    # We include the time so names sort by creation time, and add a UUID
+    # to ensure we don't collide.
+    name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}"
+    path = os.path.join(root, name)
+    os.makedirs(path, exist_ok=False)
+
+    if use_dev_shm:
+        # Record the owning pid so a later `gc_dev_shm=True` call can detect
+        # that this directory has been orphaned.
+        with open(os.path.join(path, "owner.pid"), "w") as f:
+            f.write(str(os.getpid()))
+
+    return path
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1e232e6e04260f277254c9b181c63dfeaadee62
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py
@@ -0,0 +1,345 @@
+# mypy: allow-untyped-defs
+"""Display class to aggregate and print the results of many measurements."""
+import collections
+import enum
+import itertools as it
+
+from torch.utils.benchmark.utils import common
+from torch import tensor as _tensor
+import operator
+
+__all__ = ["Colorize", "Compare"]
+
+# ANSI SGR escape sequences used by `_Row.color_segment` to style table cells.
+BEST = "\033[92m"  # bright green foreground
+GOOD = "\033[34m"  # blue foreground
+BAD = "\033[2m\033[91m"  # faint + bright red foreground
+VERY_BAD = "\033[31m"  # red foreground
+BOLD = "\033[1m"  # bold
+TERMINATE = "\033[0m"  # reset all attributes
+
+
+class Colorize(enum.Enum):
+    """Coloring strategy applied by `_Row.finalize_column_strings`."""
+    # Cells are left uncolored.
+    NONE = "none"
+    # Each cell is compared against the best median of its column within the row's group.
+    COLUMNWISE = "columnwise"
+    # Each cell is compared against the best median of its own row.
+    ROWWISE = "rowwise"
+
+
+# Classes to separate internal bookkeeping from what is rendered.
+class _Column:
+    """Formatting state shared by every cell of one table column.
+
+    Holds the column's measurements (grouped as passed in by the caller) and
+    a format template sized so that all cells of the column share one width.
+    """
+
+    def __init__(
+        self,
+        grouped_results: list[tuple[common.Measurement | None, ...]],
+        time_scale: float,
+        time_unit: str,
+        trim_significant_figures: bool,
+        highlight_warnings: bool,
+    ) -> None:
+        self._grouped_results = grouped_results
+        self._flat_results = [*it.chain.from_iterable(grouped_results)]
+        self._time_scale = time_scale
+        self._time_unit = time_unit
+        self._trim_significant_figures = trim_significant_figures
+        # Only reserve room for warning tags if some measurement in this
+        # column actually has warnings.
+        self._highlight_warnings = (
+            highlight_warnings
+            and any(r.has_warnings for r in self._flat_results if r)
+        )
+        # ceil(log10(scaled median)) per measurement; None where the cell is empty.
+        leading_digits = [
+            int(_tensor(r.median / self._time_scale).log10().ceil()) if r else None
+            for r in self._flat_results
+        ]
+        unit_digits = max(d for d in leading_digits if d is not None)
+        # When trimming, use the smallest number of decimals any measurement
+        # supports (never negative); otherwise always show one decimal.
+        decimal_digits = min(
+            max(m.significant_figures - digits, 0)
+            for digits, m in zip(leading_digits, self._flat_results, strict=True)
+            if (m is not None) and (digits is not None)
+        ) if self._trim_significant_figures else 1
+        # Add one character for the decimal point when decimals are shown.
+        length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
+        # Two fields: the right-aligned number, then a field (width 7 when
+        # highlighting warnings, otherwise 0) for the warning tag.
+        self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"
+
+    def get_results_for(self, group):
+        """Return the tuple of measurements stored for index `group`."""
+        return self._grouped_results[group]
+
+    def num_to_str(self, value: float | None, estimated_sigfigs: int, spread: float | None):
+        """Format one cell.
+
+        A `value` of None yields spaces as wide as a formatted number, so
+        empty cells keep the column aligned. `spread` is rendered as a
+        "(! NN%)" tag only when warning highlighting is active for this column.
+        """
+        if value is None:
+            return " " * len(self.num_to_str(1, estimated_sigfigs, None))
+
+        if self._trim_significant_figures:
+            value = common.trim_sigfig(value, estimated_sigfigs)
+
+        return self._template.format(
+            value,
+            f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "")
+
+
+def optional_min(seq):
+    """Return the smallest item of `seq`, or None when `seq` is empty."""
+    return min(seq, default=None)
+
+
+class _Row:
+    """One table row: the optional measurement for each column, plus layout state."""
+
+    def __init__(self, results, row_group, render_env, env_str_len,
+                 row_name_str_len, time_scale, colorize, num_threads=None) -> None:
+        super().__init__()
+        # Row content.
+        self._results = results
+        self._row_group = row_group
+        self._num_threads = num_threads
+        # Layout of the leading label cell.
+        self._render_env = render_env
+        self._env_str_len = env_str_len
+        self._row_name_str_len = row_name_str_len
+        # Number formatting and coloring.
+        self._time_scale = time_scale
+        self._colorize = colorize
+        # Filled in later through `register_columns`.
+        self._columns: tuple[_Column, ...] = ()
+
+    def register_columns(self, columns: tuple[_Column, ...]) -> None:
+        """Attach the column objects used to format this row's cells."""
+        self._columns = columns
+
+    def as_column_strings(self):
+        """Return the label cell followed by one formatted string per column."""
+        present = [r for r in self._results if r is not None]
+        env_label = ""
+        if self._render_env:
+            env_label = f"({present[0].env})"
+        padded_env = env_label.ljust(self._env_str_len + 4)
+
+        cells = ["  " + padded_env + present[0].as_row_name]
+        for measurement, column in zip(self._results, self._columns or (), strict=False):
+            if measurement is not None:
+                cells.append(column.num_to_str(
+                    measurement.median / self._time_scale,
+                    measurement.significant_figures,
+                    measurement.iqr / measurement.median if measurement.has_warnings else None
+                ))
+            else:
+                # Empty cell: the column pads it to the width of a number.
+                cells.append(column.num_to_str(None, 1, None))
+        return cells
+
+    @staticmethod
+    def color_segment(segment, value, best_value):
+        """Wrap `segment` in ANSI codes chosen by how `value` compares to `best_value`."""
+        near_best = value <= best_value * 1.01 or value <= best_value + 100e-9
+        if near_best:
+            prefix = BEST + BOLD
+        elif value <= best_value * 1.1:
+            prefix = GOOD + BOLD
+        elif value >= best_value * 5:
+            prefix = VERY_BAD + BOLD
+        elif value >= best_value * 2:
+            prefix = BAD
+        else:
+            # Between 1.1x and 2x of the best value: leave uncolored.
+            return segment
+        return prefix + segment + TERMINATE * 2
+
+    def row_separator(self, overall_width):
+        """Return a one-item list with a "N threads: ---" header, or an empty list."""
+        if self._num_threads is None:
+            return []
+        return [f"{self._num_threads} threads: ".ljust(overall_width, "-")]
+
+    def finalize_column_strings(self, column_strings, col_widths):
+        """Pad every cell to its column width and apply the configured coloring."""
+        if self._colorize == Colorize.ROWWISE:
+            row_min = min(r.median for r in self._results if r is not None)
+            best_values = [row_min] * len(column_strings)
+        elif self._colorize == Colorize.COLUMNWISE:
+            best_values = [
+                optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None)
+                for column in (self._columns or ())
+            ]
+        else:
+            best_values = [-1] * len(column_strings)
+
+        # The label cell is left-aligned; data cells are centered.
+        row_contents = [column_strings[0].ljust(col_widths[0])]
+        data_cells = zip(column_strings[1:], col_widths[1:], self._results, best_values, strict=False)
+        for text, width, result, best in data_cells:
+            text = text.center(width)
+            colorable = self._colorize != Colorize.NONE and result is not None and best is not None
+            if colorable:
+                text = self.color_segment(text, result.median, best)
+            row_contents.append(text)
+        return row_contents
+
+
+class Table:
+    """Arrange measurements that share a label into rows and columns and render them.
+
+    Rows are keyed by `(num_threads, env, as_row_name)` and columns by
+    `description`.
+    """
+
+    def __init__(
+            self,
+            results: list[common.Measurement],
+            colorize: Colorize,
+            trim_significant_figures: bool,
+            highlight_warnings: bool
+    ) -> None:
+        """Build the row/column layout for `results`.
+
+        Raises:
+            AssertionError: If the results do not all share one label.
+        """
+        if len({r.label for r in results}) != 1:
+            raise AssertionError("All results must share the same label")
+
+        self.results = results
+        self._colorize = colorize
+        self._trim_significant_figures = trim_significant_figures
+        self._highlight_warnings = highlight_warnings
+        self.label = results[0].label
+        # One unit for the whole table, chosen from the smallest median.
+        self.time_unit, self.time_scale = common.select_unit(
+            min(r.median for r in results)
+        )
+
+        self.row_keys = common.ordered_unique([self.row_fn(i) for i in results])
+        # Sort on the first two key fields (num_threads, env) only; the sort is
+        # stable, so rows keep their first-seen order within each pair.
+        self.row_keys.sort(key=operator.itemgetter(slice(2)))  # preserve stmt order
+        self.column_keys = common.ordered_unique([self.col_fn(i) for i in results])
+        self.rows, self.columns = self.populate_rows_and_columns()
+
+    @staticmethod
+    def row_fn(m: common.Measurement) -> tuple[int, str | None, str]:
+        """Return the row key of a measurement."""
+        return m.num_threads, m.env, m.as_row_name
+
+    @staticmethod
+    def col_fn(m: common.Measurement) -> str | None:
+        """Return the column key of a measurement."""
+        return m.description
+
+    def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, ...]]:
+        """Create the `_Row` and `_Column` objects and link each row to all columns."""
+        rows: list[_Row] = []
+        columns: list[_Column] = []
+        # Grid of measurements indexed [row][column]; None marks an empty cell.
+        ordered_results: list[list[common.Measurement | None]] = [
+            [None for _ in self.column_keys]
+            for _ in self.row_keys
+        ]
+        row_position = {key: i for i, key in enumerate(self.row_keys)}
+        col_position = {key: i for i, key in enumerate(self.column_keys)}
+        for r in self.results:
+            i = row_position[self.row_fn(r)]
+            j = col_position[self.col_fn(r)]
+            ordered_results[i][j] = r
+
+        # The env is only rendered when the results span more than one env.
+        unique_envs = {r.env for r in self.results}
+        render_env = len(unique_envs) > 1
+        env_str_len = max(len(i) for i in unique_envs) if render_env else 0
+
+        row_name_str_len = max(len(r.as_row_name) for r in self.results)
+
+        prior_num_threads = -1
+        prior_env = ""
+        row_group = -1
+        rows_by_group: list[list[list[common.Measurement | None]]] = []
+        for (num_threads, env, _), row in zip(self.row_keys, ordered_results, strict=True):
+            # A change in thread count starts a new row group.
+            thread_transition = (num_threads != prior_num_threads)
+            if thread_transition:
+                prior_num_threads = num_threads
+                prior_env = ""
+                row_group += 1
+                rows_by_group.append([])
+            rows.append(
+                _Row(
+                    results=row,
+                    row_group=row_group,
+                    # Show the env only on the first row of a run of equal envs.
+                    render_env=(render_env and env != prior_env),
+                    env_str_len=env_str_len,
+                    row_name_str_len=row_name_str_len,
+                    time_scale=self.time_scale,
+                    colorize=self._colorize,
+                    # Only the first row of a group carries the thread count.
+                    num_threads=num_threads if thread_transition else None,
+                )
+            )
+            rows_by_group[-1].append(row)
+            prior_env = env
+
+        for i in range(len(self.column_keys)):
+            # Column i's measurements, one tuple per row group.
+            grouped_results = [tuple(row[i] for row in g) for g in rows_by_group]
+            column = _Column(
+                grouped_results=grouped_results,
+                time_scale=self.time_scale,
+                time_unit=self.time_unit,
+                trim_significant_figures=self._trim_significant_figures,
+                highlight_warnings=self._highlight_warnings,)
+            columns.append(column)
+
+        rows_tuple, columns_tuple = tuple(rows), tuple(columns)
+        for ri in rows_tuple:
+            ri.register_columns(columns_tuple)
+        return rows_tuple, columns_tuple
+
+    def render(self) -> str:
+        """Return the table as a multi-line string."""
+        # First row is the header: an empty label cell plus the column keys.
+        string_rows = [[""] + self.column_keys]
+        string_rows.extend(r.as_column_strings() for r in self.rows)
+        # Pad short rows so that every row has the same number of cells.
+        num_cols = max(len(i) for i in string_rows)
+        for sr in string_rows:
+            sr.extend(["" for _ in range(num_cols - len(sr))])
+
+        col_widths = [max(len(j) for j in i) for i in zip(*string_rows, strict=True)]
+        finalized_columns = ["  |  ".join(i.center(w) for i, w in zip(string_rows[0], col_widths, strict=True))]
+        overall_width = len(finalized_columns[0])
+        for string_row, row in zip(string_rows[1:], self.rows, strict=True):
+            finalized_columns.extend(row.row_separator(overall_width))
+            finalized_columns.append("  |  ".join(row.finalize_column_strings(string_row, col_widths)))
+
+        newline = "\n"
+        has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results)
+        # The trailing `[1:]` drops the newline that follows the opening quotes.
+        return f"""
+[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}]
+{newline.join(finalized_columns)}
+
+Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).
+{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:]
+
+
+class Compare:
+    """Helper class for displaying the results of many measurements in a
+    formatted table.
+
+    The table format is based on the information fields provided in
+    :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`,
+    `num_threads`, etc).
+
+    The table can be directly printed using :meth:`print` or cast to a `str`.
+
+    For a full tutorial on how to use this class, see:
+    https://pytorch.org/tutorials/recipes/recipes/benchmark.html
+
+    Args:
+        results: List of Measurement to display.
+    """
+    def __init__(self, results: list[common.Measurement]) -> None:
+        self._results: list[common.Measurement] = []
+        self.extend_results(results)
+        self._trim_significant_figures = False
+        self._colorize = Colorize.NONE
+        self._highlight_warnings = False
+
+    def __str__(self) -> str:
+        """Return every rendered table, joined by newlines."""
+        return "\n".join(self._render())
+
+    def extend_results(self, results) -> None:
+        """Append results to already stored ones.
+
+        All added results must be instances of ``Measurement``. Any iterable
+        is accepted, including one-shot iterators such as generators.
+
+        Raises:
+            ValueError: If any item is not a ``Measurement``; in that case
+                nothing is appended.
+        """
+        # Materialize first: validating a one-shot iterable would otherwise
+        # exhaust it, leaving nothing for `extend` to store.
+        results = list(results)
+        for r in results:
+            if not isinstance(r, common.Measurement):
+                raise ValueError(
+                    "Expected an instance of `Measurement`, " f"got {type(r)} instead."
+                )
+        self._results.extend(results)
+
+    def trim_significant_figures(self) -> None:
+        """Enables trimming of significant figures when building the formatted table."""
+        self._trim_significant_figures = True
+
+    def colorize(self, rowwise=False) -> None:
+        """Colorize formatted table.
+
+        Colorize columnwise by default.
+        """
+        self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE
+
+    def highlight_warnings(self) -> None:
+        """Enables warning highlighting when building formatted table."""
+        self._highlight_warnings = True
+
+    def print(self) -> None:
+        """Print formatted table"""
+        print(str(self))
+
+    def _render(self):
+        """Merge replicates, group them by label and render one table per label."""
+        results = common.Measurement.merge(self._results)
+        grouped_results = self._group_by_label(results)
+        output = [self._layout(group) for group in grouped_results.values()]
+        return output
+
+    def _group_by_label(self, results: list[common.Measurement]):
+        """Return a mapping from label to the measurements carrying that label."""
+        grouped_results: collections.defaultdict[str, list[common.Measurement]] = collections.defaultdict(list)
+        for r in results:
+            grouped_results[r.label].append(r)
+        return grouped_results
+
+    def _layout(self, results: list[common.Measurement]):
+        """Render one label's measurements using the currently configured options."""
+        table = Table(
+            results,
+            self._colorize,
+            self._trim_significant_figures,
+            self._highlight_warnings
+        )
+        return table.render()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd15a582a274980bea4aff22f7325ccf562ecb13
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py
@@ -0,0 +1,195 @@
+# mypy: allow-untyped-defs
+from typing import Any, cast
+from collections.abc import Callable
+
+import torch
+import torch._dynamo
+from torch._dynamo.testing import CompileCounterWithBackend
+from torch.utils.benchmark import Timer
+
+
+__all__ = ["bench_all", "benchmark_compile"]
+
+
+# Set to True once `_enable_tensor_cores` has printed its notice, so the
+# notice is printed at most once.
+_warned_tensor_cores = False
+# Matmul precision captured at import time; `_disable_tensor_cores` restores it.
+_default_float_32_precision = torch.get_float32_matmul_precision()
+
+# `tabulate` is an optional dependency. When it is missing, a hint is printed
+# at import time and the helpers guarded by `HAS_TABULATE` below are not defined.
+try:
+
+    from tabulate import tabulate
+
+    HAS_TABULATE = True
+except ModuleNotFoundError:
+    HAS_TABULATE = False
+    tabulate = None  # type: ignore[assignment]
+    print("tabulate is not installed, please pip install tabulate to use this utility")
+
+if HAS_TABULATE:
+    def _enable_tensor_cores() -> None:
+        """Set float32 matmul precision to "high" when the CUDA device qualifies.
+
+        Nothing happens unless CUDA is available, `allow_tf32` is exactly
+        False and the device capability is at least (8, 0). A two-line notice
+        is printed the first time the precision is changed.
+        """
+        global _warned_tensor_cores
+
+        if not torch.cuda.is_available():
+            return
+
+        tf32_disabled = torch.backends.cuda.matmul.allow_tf32 is False
+        if not (tf32_disabled and torch.cuda.get_device_capability() >= (8, 0)):
+            return
+
+        torch.set_float32_matmul_precision("high")
+        if _warned_tensor_cores:
+            return
+        print("Your GPU supports tensor cores")
+        print("we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`")
+        _warned_tensor_cores = True
+
+    def _disable_tensor_cores() -> None:
+        """Restore the float32 matmul precision that was in effect at import time."""
+        torch.set_float32_matmul_precision(_default_float_32_precision)
+
+    def bench_loop(
+        model: torch.nn.Module | Callable,
+        sample_input: torch.Tensor | Any,
+        num_iters: int = 5,
+        optimizer: torch.optim.Optimizer | None = None,
+        loss_fn: Callable | None = None,
+    ):
+        """Time `model` on `sample_input` and return the mean per-iteration time.
+
+        A training step (forward, loss, backward, optimizer step, zero_grad)
+        is timed when both `optimizer` and `loss_fn` are truthy; otherwise
+        only the forward call is timed.
+
+        Returns:
+            Mean time in milliseconds, rounded to two decimals.
+        """
+        if not (optimizer and loss_fn):
+            # Inference mode
+            stmt = "model(sample_input)"
+        else:
+            # Training mode
+            stmt = """
+    output = model(sample_input)
+    loss = loss_fn(output) if loss_fn else output.sum()
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+            """
+
+        # Names that the timed statement refers to.
+        timer_globals = {
+            "model": model,
+            "sample_input": sample_input,
+            "optimizer": optimizer,
+            "loss_fn": loss_fn,
+        }
+        measurement = Timer(stmt=stmt, globals=timer_globals).timeit(number=num_iters)
+
+        # `mean` is in seconds; report milliseconds.
+        mean_ms = measurement.mean * 1000
+        return round(mean_ms, 2)
+
+    def benchmark_compile(
+        model: torch.nn.Module | Callable,
+        sample_input: torch.Tensor | Any,
+        num_iters: int = 5,
+        backend: str | None = None,
+        mode: str | None = "default",
+        optimizer: torch.optim.Optimizer | None = None,
+        loss_fn : torch.nn.Module | Callable | None = None,
+    ):
+        """
+        Use this utility to benchmark torch.compile
+
+        Returns a `(compilation_time, running_time)` pair in milliseconds.
+        With no `backend` the model is timed as-is and `compilation_time` is
+        None. If compiling or running fails, the error is printed and
+        `(None, None)` is returned.
+        """
+        if not backend:
+            # Nothing is compiled, so only the running time is measured.
+            compilation_time = None
+            running_time = bench_loop(model, sample_input, num_iters, optimizer, loss_fn)
+        else:
+            try:
+                torch._dynamo.reset()
+                compile_counter_with_backend = CompileCounterWithBackend(backend)
+                opt_model = torch.compile(model, backend=compile_counter_with_backend, mode=mode)
+
+                # Compilation only happens after the first inference
+                compilation_time = bench_loop(opt_model, sample_input, 1, optimizer, loss_fn)
+
+                running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)
+
+                compiled_frames = compile_counter_with_backend.frame_count
+                if compiled_frames == 0:
+                    raise RuntimeError("No compilation occurred during benchmarking.")
+                if compiled_frames > 1:
+                    raise RuntimeError("Recompilation occurred during benchmarking.")
+
+            except Exception as e:
+                print(e)
+                print(f"Failed to compile {backend} with mode {mode}")
+                return None, None
+
+        # Falsy timings (None or 0) are reported as None.
+        return (
+            round(compilation_time, 2) if compilation_time else None,
+            round(running_time, 2) if running_time else None,
+        )
+
+
+    def bench_all(
+        model : torch.nn.Module | Callable,
+        sample_input: torch.Tensor | Any,
+        num_iters : int = 5,
+        optimizer: torch.optim.Optimizer | None = None,
+        loss_fn : torch.nn.Module | Callable | None = None,
+    ):
+        """
+        This is a simple utility that can be used to benchmark torch.compile
+        In particular it ensures that your GPU is set up to use tensor cores if it supports them
+        It also tries out all the main backends and prints a table of results so you can easily compare them all
+        Many of the backends have their own optional dependencies so please pip install them separately
+
+        The returned table is labelled "Training" when an optimizer is passed and "Inference" otherwise
+        If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer
+        together with a loss_fn
+
+        The important warnings are
+        Your GPU supports tensor cores
+        we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`
+
+        If a compilation fails for any reason including the dependency not being included
+        then we will print Failed to compile {backend} with mode {mode}
+
+        Returns the table as a string in tabulate's "github" format
+        """
+        field_names = ["Train/Inference", "Backend", "Mode", "Compilation Time", "Average Running Time"]
+        table = []
+        run_kind = "Training" if optimizer else "Inference"
+
+        torch._dynamo.reset()
+        # Forward `loss_fn` so the eager baseline times the same workload as
+        # the compiled rows it is compared against.
+        _, eager_time = benchmark_compile(model, sample_input, num_iters, None, None, optimizer, loss_fn)
+        table.append([run_kind, "Eager", "-", "-", f"{eager_time} ms"])
+
+        for backend in torch._dynamo.list_backends():
+
+            if backend == "inductor":
+                mode_options = cast(list[str | None], list(torch._inductor.list_mode_options().keys())) + [None]
+                for mode in mode_options:
+                    if mode == "default":
+                        continue
+                    torch._dynamo.reset()
+                    try:
+                        if torch.cuda.is_available():
+                            _enable_tensor_cores()
+                        compilation_time, running_time = benchmark_compile(
+                            model, sample_input, num_iters, backend, mode, optimizer, loss_fn)
+                    finally:
+                        # Only the precision reset belongs in `finally`.
+                        if torch.cuda.is_available():
+                            _disable_tensor_cores()
+
+                    # Record the row whether or not CUDA is present; a failed
+                    # compile shows up as "-" in both timing columns.
+                    table.append([
+                        run_kind,
+                        # pyrefly: ignore [redundant-condition]
+                        backend if backend else "-",
+                        mode if mode is not None else "-",
+                        f"{compilation_time} ms " if compilation_time else "-",
+                        f"{running_time} ms " if running_time else "-",
+                    ])
+
+            else:
+                torch._dynamo.reset()
+                compilation_time, running_time = benchmark_compile(
+                    model, sample_input, num_iters, backend, None, optimizer, loss_fn)
+
+                if running_time is not None:
+                    table.append([
+                        run_kind,
+                        backend, "-",
+                        # An f-string is always truthy, so test the value itself.
+                        f"{compilation_time} ms " if compilation_time else "-",
+                        f"{running_time} ms ",
+                    ])
+
+
+        # pyrefly: ignore [not-callable]
+        return tabulate(table, headers=field_names, tablefmt="github")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a298146ce17c7ff6f303b4d76c4c96ba786ae774
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py
@@ -0,0 +1,175 @@
+"""JIT C++ strings into executables."""
+import atexit
+import os
+import re
+import shutil
+import textwrap
+import threading
+from typing import Any
+
+import torch
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType
+from torch.utils.benchmark.utils.common import _make_temp_dir
+from torch.utils import cpp_extension
+
+
+LOCK = threading.Lock()
+SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0]
+
+# The build root is created once per process (lazily, on the first call to
+# `_get_build_root`) so that separate processes will have separate build
+# roots, but threads will share the same build root.
+# `cpp_extension` uses build root as part of the cache key, so a per-invocation
+# build root (e.g. a different one per _compile_template call) would lead to
+# a 0% cache hit rate and spurious recompilation. Consider the following:
+#   ```
+#   setup = "auto x = torch::ones({1024, 1024});"
+#   stmt = "torch::mm(x, x);"
+#   for num_threads in [1, 2, 4, 8]:
+#     print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange())
+#   ```
+# `setup` and `stmt` do not change, so we can reuse the executable from the
+# first pass through the loop.
+_BUILD_ROOT: str | None = None
+
+def _get_build_root() -> str:
+    global _BUILD_ROOT
+    if _BUILD_ROOT is None:
+        _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
+        # pyrefly: ignore [missing-argument]
+        atexit.register(shutil.rmtree, _BUILD_ROOT)
+    return _BUILD_ROOT
+
+
+# BACK_TESTING_NOTE:
+#   There are two workflows where this code could be used. One is the obvious
+#   case where someone simply builds or installs PyTorch and uses Timer.
+#   The other is that the entire `torch/utils/benchmark` folder from a CURRENT
+#   PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
+#   source code. This is what we refer to here as "back testing". The rationale
+#   is that we might want to use current tooling to study some aspect of an
+#   earlier version of PyTorch. (e.g. a regression.)
+#
+#   The problem is that Timer relies on several aspects of core PyTorch, namely
+#   some binding functions for Valgrind symbols in `torch._C` and the
+#   `torch.__config__._cxx_flags()` method. If we were to naively copy code
+#   around this wouldn't work as the symbols of interest aren't present in
+#   earlier versions of PyTorch. In order to work around this, we must add back
+#   testing shims. These shims will never activate during normal use, but will
+#   allow Timer to function outside of the "correct" version of PyTorch by
+#   emulating functionality that was added later.
+#
+#   These shims are temporary, and as Timer becomes more integrated with
+#   PyTorch the cost and complexity of such shims will increase. Once back
+#   testing is no longer required (which is to say we have done enough historic
+#   analysis and the shims no longer justify their maintenance and code
+#   complexity costs) back testing paths will be removed.
+
+CXX_FLAGS: list[str] | None
+if hasattr(torch.__config__, "_cxx_flags"):
+    try:
+        CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
+        if CXX_FLAGS is not None and "-g" not in CXX_FLAGS:
+            CXX_FLAGS.append("-g")
+        # remove "-W" flags to allow build benchmarks
+        # with a relaxed constraint of compiler versions
+        if CXX_FLAGS is not None:
+            CXX_FLAGS = list(filter(lambda x: not x.startswith("-W"), CXX_FLAGS))
+
+    except RuntimeError:
+        # We are in FBCode.
+        CXX_FLAGS = None
+else:
+    # FIXME: Remove when back testing is no longer required.
+    CXX_FLAGS = ["-O2", "-fPIC", "-g"]
+
+EXTRA_INCLUDE_PATHS: list[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
+CONDA_PREFIX = os.getenv("CONDA_PREFIX")
+if CONDA_PREFIX is not None:
+    # Load will automatically search /usr/include, but not conda include.
+    EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include"))
+
+
+COMPAT_CALLGRIND_BINDINGS: CallgrindModuleType | None = None
+def get_compat_bindings() -> CallgrindModuleType:
+    with LOCK:
+        global COMPAT_CALLGRIND_BINDINGS
+        if COMPAT_CALLGRIND_BINDINGS is None:
+            COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
+                name="callgrind_bindings",
+                sources=[os.path.join(
+                    SOURCE_ROOT,
+                    "valgrind_wrapper",
+                    "compat_bindings.cpp"
+                )],
+                extra_cflags=CXX_FLAGS,
+                extra_include_paths=EXTRA_INCLUDE_PATHS,
+            )
+    return COMPAT_CALLGRIND_BINDINGS
+
+
+def _compile_template(
+    *,
+    stmt: str,
+    setup: str,
+    global_setup: str,
+    src: str,
+    is_standalone: bool
+) -> Any:
+    for before, after, indentation in (
+        ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
+        ("// SETUP_TEMPLATE_LOCATION", setup, 4),
+        ("// STMT_TEMPLATE_LOCATION", stmt, 8)
+    ):
+        # C++ doesn't care about indentation so this code isn't load
+        # bearing the way it is with Python, but this makes the source
+        # look nicer if a human has to look at it.
+        src = re.sub(
+            before,
+            textwrap.indent(after, " " * indentation)[indentation:],
+            src
+        )
+
+    # We want to isolate different Timers. However `cpp_extension` will
+    # cache builds which will significantly reduce the cost of repeated
+    # invocations.
+    with LOCK:
+        name = f"timer_cpp_{abs(hash(src))}"
+        build_dir = os.path.join(_get_build_root(), name)
+        os.makedirs(build_dir, exist_ok=True)
+
+        src_path = os.path.join(build_dir, "timer_src.cpp")
+        with open(src_path, "w") as f:
+            f.write(src)
+
+    # `cpp_extension` has its own locking scheme, so we don't need our lock.
+    return cpp_extension.load(
+        name=name,
+        sources=[src_path],
+        build_directory=build_dir,
+        extra_cflags=CXX_FLAGS,
+        extra_include_paths=EXTRA_INCLUDE_PATHS,
+        is_python_module=not is_standalone,
+        is_standalone=is_standalone,
+    )
+
+
+def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
+    template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
+    with open(template_path) as f:
+        src: str = f.read()
+
+    module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)
+    if not isinstance(module, TimeitModuleType):
+        raise AssertionError("compiled module is not a TimeitModuleType")
+    return module
+
+
+def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
+    template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
+    with open(template_path) as f:
+        src: str = f.read()
+
+    target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)
+    if not isinstance(target, str):
+        raise AssertionError("compiled target path is not a string")
+    return target
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f771d23632efd27239e460591d923be3ee59fc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py
@@ -0,0 +1,469 @@
+# mypy: allow-untyped-defs
+import functools
+import itertools as it
+from typing import Any
+from collections.abc import Callable
+
+import torch
+
+
+__all__ = [
+    "Fuzzer",
+    "FuzzedParameter", "ParameterAlias",
+    "FuzzedTensor",
+]
+
+
+# Names of the built-in distributions accepted by `FuzzedParameter`. Any
+# other `distribution` value that is not a dict is rejected.
+_DISTRIBUTIONS = (
+    "loguniform",
+    "uniform",
+)
+
+
+class FuzzedParameter:
+    """Specification for a parameter to be generated during fuzzing."""
+    def __init__(
+        self,
+        name: str,
+        minval: int | float | None = None,
+        maxval: int | float | None = None,
+        distribution: str | dict[Any, float] | None = None,
+        strict: bool = False,
+    ) -> None:
+        """
+        Args:
+            name:
+                A string name with which to identify the parameter.
+                FuzzedTensors can reference this string in their
+                specifications.
+            minval:
+                The lower bound for the generated value. See the description
+                of `distribution` for type behavior.
+            maxval:
+                The upper bound for the generated value. Type behavior is
+                identical to `minval`.
+            distribution:
+                Specifies the distribution from which this parameter should
+                be drawn. There are three possibilities:
+                    - "loguniform"
+                        Samples between `minval` and `maxval` (inclusive) such
+                        that the probabilities are uniform in log space. As a
+                        concrete example, if minval=1 and maxval=100, a sample
+                        is as likely to fall in [1, 10) as it is [10, 100].
+                    - "uniform"
+                        Samples are chosen with uniform probability between
+                        `minval` and `maxval` (inclusive). If either `minval`
+                        or `maxval` is a float then the distribution is the
+                        continuous uniform distribution; otherwise samples
+                        are constrained to the integers.
+                    - dict:
+                        If a dict is passed, the keys are taken to be choices
+                        for the variables and the values are interpreted as
+                        probabilities. (And must sum to one.)
+                If a dict is passed, `minval` and `maxval` must not be set.
+                Otherwise, they must be set.
+            strict:
+                If a parameter is strict, it will not be included in the
+                iterative resampling process which Fuzzer uses to find a
+                valid parameter configuration. This allows an author to
+                prevent skew from resampling for a given parameter (for
+                instance, a low size limit could inadvertently bias towards
+                Tensors with fewer dimensions) at the cost of more iterations
+                when generating parameters.
+        """
+        self._name = name
+        self._minval = minval
+        self._maxval = maxval
+        self._distribution = self._check_distribution(distribution)
+        self.strict = strict
+
+    @property
+    def name(self):
+        return self._name
+
+    def sample(self, state):
+        if self._distribution == "loguniform":
+            return self._loguniform(state)
+
+        if self._distribution == "uniform":
+            return self._uniform(state)
+
+        if isinstance(self._distribution, dict):
+            return self._custom_distribution(state)
+
+    def _check_distribution(self, distribution):
+        if not isinstance(distribution, dict):
+            if distribution not in _DISTRIBUTIONS:
+                raise AssertionError(f"Unknown distribution: {distribution}")
+        else:
+            if any(i < 0 for i in distribution.values()):
+                raise AssertionError("Probabilities cannot be negative")
+            if not abs(sum(distribution.values()) - 1) > 1e-5:
+                raise AssertionError("Distribution is not normalized")
+            if self._minval is not None:
+                raise AssertionError("When passing a custom distribution, 'minval' must be None")
+            if self._maxval is not None:
+                raise AssertionError("When passing a custom distribution, 'maxval' must be None")
+
+        return distribution
+
+    def _loguniform(self, state):
+        import numpy as np
+        output = int(2 ** state.uniform(
+            low=np.log2(self._minval) if self._minval is not None else None,
+            high=np.log2(self._maxval) if self._maxval is not None else None,
+        ))
+        if self._minval is not None and output < self._minval:
+            return self._minval
+        if self._maxval is not None and output > self._maxval:
+            return self._maxval
+        return output
+
+    def _uniform(self, state):
+        if isinstance(self._minval, int) and isinstance(self._maxval, int):
+            return int(state.randint(low=self._minval, high=self._maxval + 1))
+        return state.uniform(low=self._minval, high=self._maxval)
+
+    def _custom_distribution(self, state):
+        import numpy as np
+        # If we directly pass the keys to `choice`, numpy will convert
+        # them to numpy dtypes.
+        index = state.choice(
+            np.arange(len(self._distribution)),
+            p=tuple(self._distribution.values()))
+        return list(self._distribution.keys())[index]
+
+
+class ParameterAlias:
+    """Indicates that a parameter should alias the value of another parameter.
+
+    When used in conjunction with a custom distribution, this allows fuzzed
+    tensors to represent a broader range of behaviors. For example, the
+    following sometimes produces Tensors which broadcast:
+
+    Fuzzer(
+        parameters=[
+            FuzzedParameter("x_len", 4, 1024, distribution="uniform"),
+
+            # `y` will either be size one, or match the size of `x`.
+            FuzzedParameter("y_len", distribution={
+                0.5: 1,
+                0.5: ParameterAlias("x_len")
+            }),
+        ],
+        tensors=[
+            FuzzedTensor("x", size=("x_len",)),
+            FuzzedTensor("y", size=("y_len",)),
+        ],
+    )
+
+    Chains of alias' are allowed, but may not contain cycles.
+    """
+    def __init__(self, alias_to) -> None:
+        self.alias_to = alias_to
+
+    def __repr__(self) -> str:
+        return f"ParameterAlias[alias_to: {self.alias_to}]"
+
+
+def dtype_size(dtype):
+    if dtype == torch.bool:
+        return 1
+    if dtype.is_floating_point or dtype.is_complex:
+        return int(torch.finfo(dtype).bits / 8)
+    return int(torch.iinfo(dtype).bits / 8)
+
+
+def prod(values, base=1):
+    """np.prod can overflow, so for sizes the product should be done in Python.
+
+    Even though np.prod type promotes to int64, it can still overflow in which
+    case the negative value will pass the size check and OOM when attempting to
+    actually allocate the Tensor.
+    """
+    return functools.reduce(lambda x, y: int(x) * int(y), values, base)
+
+
+class FuzzedTensor:
+    """Specification for a Tensor to be generated during fuzzing.
+
+    Sizes and steps may refer to parameters by name; they are resolved to
+    concrete values from the sampled parameter dict when a Tensor is built.
+    """
+    def __init__(
+        self,
+        name: str,
+        size: tuple[str | int, ...],
+        steps: tuple[str | int, ...] | None = None,
+        probability_contiguous: float = 0.5,
+        min_elements: int | None = None,
+        max_elements: int | None = None,
+        max_allocation_bytes: int | None = None,
+        dim_parameter: str | None = None,
+        roll_parameter: str | None = None,
+        dtype=torch.float32,
+        cuda=False,
+        tensor_constructor: Callable | None = None
+    ) -> None:
+        """
+        Args:
+            name:
+                A string identifier for the generated Tensor.
+            size:
+                A tuple of integers or strings specifying the size of the generated
+                Tensor. String values will be replaced with a concrete int during the
+                generation process, while ints are simply passed as literals.
+            steps:
+                An optional tuple with the same length as `size`. This indicates
+                that a larger Tensor should be allocated, and then sliced to
+                produce the generated Tensor. For instance, if size is (4, 8)
+                and steps is (1, 4), then a tensor `t` of size (4, 32) will be
+                created and then `t[:, ::4]` will be used. (Allowing one to test
+                Tensors with strided memory.)
+            probability_contiguous:
+                A number between zero and one. With probability
+                `1 - probability_contiguous` the allocated Tensor is randomly
+                permuted, made `.contiguous()`, and permuted back, which
+                reorders its memory. This is applied before `steps`, which can
+                also cause a Tensor to be non-contiguous.
+            min_elements:
+                The minimum number of elements that this Tensor must have for a
+                set of parameters to be valid. (Otherwise they are resampled.)
+            max_elements:
+                Like `min_elements`, but setting an upper bound.
+            max_allocation_bytes:
+                Like `max_elements`, but for the size of Tensor that must be
+                allocated prior to slicing for `steps` (if applicable). For
+                example, a FloatTensor with size (1024, 1024) and steps (4, 4)
+                would have 1M elements, but would require a 64 MB allocation.
+            dim_parameter:
+                Name of the parameter giving the number of dimensions. `size`
+                and `steps` are truncated to this length, or padded with ones
+                if shorter. This allows Tensors of varying dimensions to be
+                generated by the Fuzzer.
+            roll_parameter:
+                Accepted for interface compatibility; it is not stored or
+                used by this class.
+            dtype:
+                The PyTorch dtype of the generated Tensor.
+            cuda:
+                Whether to place the Tensor on a GPU.
+            tensor_constructor:
+                Callable which will be used instead of the default Tensor
+                construction method. This allows the author to enforce properties
+                of the Tensor (e.g. it can only have certain values). The dtype and
+                concrete shape of the Tensor to be created will be passed, and
+                concrete values of all parameters will be passed as kwargs. Note
+                that transformations to the result (permuting, slicing) will be
+                performed by the Fuzzer; the tensor_constructor is only responsible
+                for creating an appropriately sized Tensor.
+        """
+        self._name = name
+        self._size = size
+        self._steps = steps
+        self._probability_contiguous = probability_contiguous
+        self._min_elements = min_elements
+        self._max_elements = max_elements
+        self._max_allocation_bytes = max_allocation_bytes
+        self._dim_parameter = dim_parameter
+        self._dtype = dtype
+        self._cuda = cuda
+        self._tensor_constructor = tensor_constructor
+
+    @property
+    def name(self):
+        """The identifier under which the generated Tensor is reported."""
+        return self._name
+
+    @staticmethod
+    def default_tensor_constructor(size, dtype, **kwargs):
+        """Create a random CPU Tensor of the given size and dtype.
+
+        Floating point and complex dtypes use `torch.rand`; all other dtypes
+        use `torch.randint(1, 127, ...)`. Extra keyword arguments (the sampled
+        parameters) are accepted and ignored.
+        """
+        if dtype.is_floating_point or dtype.is_complex:
+            return torch.rand(size=size, dtype=dtype, device="cpu")
+        else:
+            return torch.randint(1, 127, size=size, dtype=dtype, device="cpu")
+
+    def _make_tensor(self, params, state):
+        """Build one Tensor from concrete `params`, drawing randomness from `state`.
+
+        Returns:
+            A `(tensor, properties)` pair, where `properties` is a dict with
+            the keys "numel", "order", "steps", "is_contiguous" and "dtype".
+        """
+        import numpy as np
+        size, steps, allocation_size = self._get_size_and_steps(params)
+        constructor = (
+            self._tensor_constructor or
+            self.default_tensor_constructor
+        )
+
+        # Allocate the full (pre-slicing) Tensor; all sampled parameters are
+        # forwarded to the constructor as keyword arguments.
+        raw_tensor = constructor(size=allocation_size, dtype=self._dtype, **params)
+        if self._cuda:
+            raw_tensor = raw_tensor.cuda()
+
+        # Randomly permute the Tensor and call `.contiguous()` to force re-ordering
+        # of the memory, and then permute it back to the original shape.
+        dim = len(size)
+        order = np.arange(dim)
+        if state.rand() > self._probability_contiguous:
+            # Redraw until the permutation differs from the identity; with
+            # fewer than two dims no such permutation exists, so skip the loop.
+            while dim > 1 and np.all(order == np.arange(dim)):
+                order = state.permutation(raw_tensor.dim())
+
+            raw_tensor = raw_tensor.permute(tuple(order)).contiguous()
+            raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))
+
+        # Take every `step`-th element along each dim to get the requested size.
+        slices = [slice(0, size * step, step) for size, step in zip(size, steps, strict=True)]
+        tensor = raw_tensor[tuple(slices)]
+
+        properties = {
+            "numel": int(tensor.numel()),
+            "order": order,
+            "steps": steps,
+            "is_contiguous": tensor.is_contiguous(),
+            "dtype": str(self._dtype),
+        }
+
+        return tensor, properties
+
+    def _get_size_and_steps(self, params):
+        """Resolve `size` and `steps` against `params`.
+
+        Returns:
+            A `(size, steps, allocation_size)` triple of equal-length tuples,
+            where `allocation_size[i] == size[i] * steps[i]`.
+        """
+        dim = (
+            params[self._dim_parameter]
+            if self._dim_parameter is not None
+            else len(self._size)
+        )
+
+        def resolve(values, dim):
+            """Resolve values into concrete integers.
+
+            Entries that are keys of `params` are replaced by the sampled
+            value; the result is truncated or padded with ones to length `dim`.
+            """
+            values = tuple(params.get(i, i) for i in values)
+            if len(values) > dim:
+                values = values[:dim]
+            if len(values) < dim:
+                values = values + tuple(1 for _ in range(dim - len(values)))
+            return values
+
+        size = resolve(self._size, dim)
+        # A missing `steps` resolves to all ones, i.e. no striding.
+        steps = resolve(self._steps or (), dim)
+        allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps, strict=True))
+        return size, steps, allocation_size
+
+    def satisfies_constraints(self, params) -> bool:
+        """Return whether `params` respect the element and allocation limits.
+
+        Limits that were left as None are not checked.
+
+        Raises:
+            AssertionError: If the computed number of elements is negative.
+        """
+        size, _, allocation_size = self._get_size_and_steps(params)
+        # Product is computed in Python to avoid integer overflow.
+        num_elements = prod(size)
+        if num_elements < 0:
+            raise AssertionError("Computed number of elements is negative")
+
+        allocation_bytes = prod(allocation_size, base=dtype_size(self._dtype))
+
+        def nullable_greater(left, right):
+            # A None operand means "no limit", which is never exceeded.
+            if left is None or right is None:
+                return False
+            return left > right
+
+        return not any((
+            nullable_greater(num_elements, self._max_elements),
+            nullable_greater(self._min_elements, num_elements),
+            nullable_greater(allocation_bytes, self._max_allocation_bytes),
+        ))
+
+
+class Fuzzer:
+    """Generates sets of parameters and the Tensors derived from them."""
+    def __init__(
+        self,
+        parameters: list[FuzzedParameter | list[FuzzedParameter]],
+        tensors: list[FuzzedTensor | list[FuzzedTensor]],
+        constraints: list[Callable] | None = None,
+        seed: int | None = None
+    ) -> None:
+        """
+        Args:
+            parameters:
+                List of FuzzedParameters which provide specifications
+                for generated parameters. Iterable elements will be
+                unpacked, though arbitrary nested structures will not.
+            tensors:
+                List of FuzzedTensors which define the Tensors which
+                will be created each step based on the parameters for
+                that step. Iterable elements will be unpacked, though
+                arbitrary nested structures will not.
+            constraints:
+                List of callables. Each is called with the dict of
+                candidate params as its single positional argument, and
+                if any of them returns a falsy value the current set of
+                parameters will be rejected.
+            seed:
+                Seed for the RandomState used by the Fuzzer. A value drawn
+                from that RandomState is also used to set the PyTorch
+                random seed so that random ops will create reproducible
+                Tensors. If None, a random seed is chosen.
+
+        Raises:
+            ValueError: If a parameter and a tensor share a name.
+        """
+        import numpy as np
+        if seed is None:
+            seed = int(np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64))
+        self._seed = seed
+        self._parameters = Fuzzer._unpack(parameters, FuzzedParameter)
+        self._tensors = Fuzzer._unpack(tensors, FuzzedTensor)
+        self._constraints = constraints or ()
+
+        p_names = {p.name for p in self._parameters}
+        t_names = {t.name for t in self._tensors}
+        name_overlap = p_names.intersection(t_names)
+        if name_overlap:
+            raise ValueError(f"Duplicate names in parameters and tensors: {name_overlap}")
+
+        # Counters backing `rejection_rate`; updated by `_generate`.
+        self._rejections = 0
+        self._total_generated = 0
+
+    @staticmethod
+    def _unpack(values, cls):
+        """Flatten one level: instances of `cls` are kept, other elements are iterated."""
+        return tuple(it.chain.from_iterable(
+            [[i] if isinstance(i, cls) else i for i in values]
+        ))
+
+    def take(self, n):
+        """Yield `n` tuples of `(tensors, tensor_properties, params)`.
+
+        `tensors` and `tensor_properties` are dicts keyed by tensor name. The
+        random state is re-created from the stored seed on every call, and
+        `torch.manual_seed` is called with a value drawn from it.
+        """
+        import numpy as np
+        state = np.random.RandomState(self._seed)
+        torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64))
+        for _ in range(n):
+            params = self._generate(state)
+            tensors = {}
+            tensor_properties = {}
+            for t in self._tensors:
+                tensor, properties = t._make_tensor(params, state)
+                tensors[t.name] = tensor
+                tensor_properties[t.name] = properties
+            yield tensors, tensor_properties, params
+
+    @property
+    def rejection_rate(self):
+        """Fraction of candidate parameter sets rejected so far (0. if none generated)."""
+        if not self._total_generated:
+            return 0.
+        return self._rejections / self._total_generated
+
+    def _generate(self, state):
+        """Sample parameter sets until one satisfies every constraint.
+
+        Makes at most 1000 attempts. Strict parameters are sampled on the
+        first attempt and reused on every later one; all other parameters
+        are resampled each attempt.
+
+        Raises:
+            ValueError: If no attempt produced a valid set of parameters.
+        """
+        strict_params: dict[str, float | int | ParameterAlias] = {}
+        for _ in range(1000):
+            candidate_params: dict[str, float | int | ParameterAlias] = {}
+            for p in self._parameters:
+                if p.strict:
+                    if p.name in strict_params:
+                        candidate_params[p.name] = strict_params[p.name]
+                    else:
+                        candidate_params[p.name] = p.sample(state)
+                        strict_params[p.name] = candidate_params[p.name]
+                else:
+                    candidate_params[p.name] = p.sample(state)
+
+            candidate_params = self._resolve_aliases(candidate_params)
+
+            self._total_generated += 1
+            # User-supplied constraints receive the params dict positionally.
+            if not all(f(candidate_params) for f in self._constraints):
+                self._rejections += 1
+                continue
+
+            if not all(t.satisfies_constraints(candidate_params) for t in self._tensors):
+                self._rejections += 1
+                continue
+
+            return candidate_params
+        raise ValueError("Failed to generate a set of valid parameters.")
+
+    @staticmethod
+    def _resolve_aliases(params):
+        """Return a copy of `params` with every ParameterAlias replaced by its target's value.
+
+        Raises:
+            ValueError: If a full pass leaves the number of aliases unchanged,
+                which is reported as an alias cycle.
+            KeyError: If an alias names a parameter that is not in `params`.
+        """
+        params = dict(params)
+        alias_count = sum(isinstance(v, ParameterAlias) for v in params.values())
+
+        keys = list(params.keys())
+        while alias_count:
+            # One pass follows each alias a single step; chains need several passes.
+            for k in keys:
+                v = params[k]
+                if isinstance(v, ParameterAlias):
+                    params[k] = params[v.alias_to]
+            alias_count_new = sum(isinstance(v, ParameterAlias) for v in params.values())
+            if alias_count == alias_count_new:
+                raise ValueError(f"ParameterAlias cycle detected\n{params}")
+
+            alias_count = alias_count_new
+
+        return params
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a573b9b44fdc2ee3c5141d8badc75a2b484a78
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py
@@ -0,0 +1,126 @@
+# mypy: allow-untyped-defs
+from numbers import Number
+import torch
+from torch.utils.benchmark import FuzzedTensor
+import math
+
+class FuzzedSparseTensor(FuzzedTensor):
+    def __init__(
+        self,
+        name: str,
+        size: tuple[str | int, ...],
+        min_elements: int | None = None,
+        max_elements: int | None = None,
+        dim_parameter: str | None = None,
+        sparse_dim: str | None = None,
+        nnz: str | None = None,
+        density: str | None = None,
+        coalesced: str | None = None,
+        dtype=torch.float32,
+        cuda=False
+    ) -> None:
+        """
+        Args:
+            name:
+                A string identifier for the generated Tensor.
+            size:
+                A tuple of integers or strings specifying the size of the generated
+                Tensor. String values will replaced with a concrete int during the
+                generation process, while ints are simply passed as literals.
+            min_elements:
+                The minimum number of parameters that this Tensor must have for a
+                set of parameters to be valid. (Otherwise they are resampled.)
+            max_elements:
+                Like `min_elements`, but setting an upper bound.
+            dim_parameter:
+                The length of `size` will be truncated to this value.
+                This allows Tensors of varying dimensions to be generated by the
+                Fuzzer.
+            sparse_dim:
+                The number of sparse dimensions in a sparse tensor.
+            density:
+                This value allows tensors of varying sparsities to be generated by the Fuzzer.
+            coalesced:
+                The sparse tensor format permits uncoalesced sparse tensors,
+                where there may be duplicate coordinates in the indices.
+            dtype:
+                The PyTorch dtype of the generated Tensor.
+            cuda:
+                Whether to place the Tensor on a GPU.
+        """
+        super().__init__(name=name, size=size, min_elements=min_elements,
+                         max_elements=max_elements, dim_parameter=dim_parameter, dtype=dtype, cuda=cuda)
+        self._density = density
+        self._coalesced = coalesced
+        self._sparse_dim = sparse_dim
+
+    @staticmethod
+    def sparse_tensor_constructor(size, dtype, sparse_dim, nnz, is_coalesced):
+        """sparse_tensor_constructor creates a sparse tensor with coo format.
+
+        Note that when `is_coalesced` is False, the number of elements is doubled but the number of indices
+        represents the same amount of number of non zeros `nnz`, i.e, this is virtually the same tensor
+        with the same sparsity pattern. Moreover, most of the sparse operation will use coalesce() method
+        and what we want here is to get a sparse tensor with the same `nnz` even if this is coalesced or not.
+
+        In the other hand when `is_coalesced` is True the number of elements is reduced in the coalescing process
+        by an unclear amount however the probability to generate duplicates indices are low for most of the cases.
+        This decision was taken on purpose to maintain the construction cost as low as possible.
+        """
+        if isinstance(size, Number):
+            size = [size] * sparse_dim
+        if all(size[d] <= 0 for d in range(sparse_dim)) and nnz != 0:
+            raise AssertionError('invalid arguments')
+        v_size = [nnz] + list(size[sparse_dim:])
+        if dtype.is_floating_point:
+            v = torch.rand(size=v_size, dtype=dtype, device="cpu")
+        else:
+            v = torch.randint(1, 127, size=v_size, dtype=dtype, device="cpu")
+
+        i = torch.rand(sparse_dim, nnz, device="cpu")
+        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
+        i = i.to(torch.long)
+
+        if not is_coalesced:
+            v = torch.cat([v, torch.randn_like(v)], 0)
+            i = torch.cat([i, i], 1)
+
+        x = torch.sparse_coo_tensor(i, v, torch.Size(size))
+        if is_coalesced:
+            x = x.coalesce()
+        return x
+
+    def _make_tensor(self, params, state):
+        # pyrefly: ignore [missing-attribute]
+        size, _, _ = self._get_size_and_steps(params)
+        density = params['density']
+        nnz = math.ceil(sum(size) * density)
+        if nnz > sum(size):
+            raise AssertionError('nnz cannot exceed total number of elements')
+
+        is_coalesced = params['coalesced']
+        sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
+        sparse_dim = min(sparse_dim, len(size))
+        # pyrefly: ignore [missing-attribute]
+        tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)
+
+        # pyrefly: ignore [missing-attribute]
+        if self._cuda:
+            tensor = tensor.cuda()
+        sparse_dim = tensor.sparse_dim()
+        dense_dim = tensor.dense_dim()
+        is_hybrid = len(size[sparse_dim:]) > 0
+
+        properties = {
+            "numel": int(tensor.numel()),
+            "shape": tensor.size(),
+            "is_coalesced": tensor.is_coalesced(),
+            "density": density,
+            "sparsity": 1.0 - density,
+            "sparse_dim": sparse_dim,
+            "dense_dim": dense_dim,
+            "is_hybrid": is_hybrid,
+            # pyrefly: ignore [missing-attribute]
+            "dtype": str(self._dtype),
+        }
+        return tensor, properties
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..30b6f79c0b5aebca676035ac0b7c08cfce639b23
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp
@@ -0,0 +1,43 @@
+/* C++ template for Timer.timeit
+
+This template will be consumed by `cpp_jit.py`, and will replace:
+    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
+    `SETUP_TEMPLATE_LOCATION`
+      and
+    `STMT_TEMPLATE_LOCATION`
+sections with user provided statements.
+*/
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
+
+// Runs the user statement once as a warmup, then `n` more times, and returns
+// the wall-clock time spent in the `n`-iteration loop, in seconds.
+double timeit(int n) {
+  // The GIL is released for the whole measurement.
+  pybind11::gil_scoped_release no_gil;
+
+  // Setup
+  // SETUP_TEMPLATE_LOCATION
+
+  {
+    // Warmup
+    // STMT_TEMPLATE_LOCATION
+  }
+
+  // Main loop
+  auto start_time = std::chrono::high_resolution_clock::now();
+  for (const auto loop_idx : c10::irange(n)) {
+    (void)loop_idx;
+    // STMT_TEMPLATE_LOCATION
+  }
+  auto end_time = std::chrono::high_resolution_clock::now();
+  // Spell out the floating-point representation: with the default period of
+  // one second, count() yields fractional seconds. Without the template
+  // argument the duration keeps the clock's native integer tick type and
+  // count() would return raw ticks instead.
+  return std::chrono::duration<double>(end_time - start_time).count();
+}
+
+// Exposes the C++ `timeit` function above to Python under the name "timeit".
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("timeit", &timeit);
+}
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f131261b8f36d08e4d9ef87605f379c4215d63ea
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py
@@ -0,0 +1,533 @@
+"""Timer class based on the timeit.Timer class, but torch aware."""
+import enum
+import timeit
+import textwrap
+from typing import overload, Any, NoReturn
+from collections.abc import Callable
+
+import torch
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType
+from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface
+
+
+__all__ = ["Timer", "timer", "Language"]
+
+
+# Default time source for `Timer`. When an accelerator is available the clock
+# is only read after synchronizing it; otherwise the stock `timeit` clock is
+# used as-is.
+if torch.accelerator.is_available():
+    def timer() -> float:
+        """Synchronize the accelerator, then return `timeit.default_timer()`."""
+        torch.accelerator.synchronize()
+        return timeit.default_timer()
+else:
+    # Bound directly rather than wrapped: `CPPTimer` checks
+    # `timer is not timeit.default_timer`, so object identity matters here.
+    timer = timeit.default_timer
+
+
+class Language(enum.Enum):
+    """Language of the `stmt` / `setup` snippets handed to `Timer`."""
+    PYTHON = 0
+    CPP = 1
+
+
+class CPPTimer:
+    """Timer backend for C++ snippets, used by `Timer` in place of `timeit.Timer`.
+
+    The snippets are stored at construction time; the timing module is only
+    built on the first call to `timeit` and is cached afterwards.
+    """
+
+    def __init__(
+        self,
+        stmt: str,
+        setup: str,
+        global_setup: str,
+        timer: Callable[[], float],
+        globals: dict[str, Any],
+    ) -> None:
+        """Validate the arguments and store the dedented snippets.
+
+        Raises:
+            NotImplementedError: If `timer` is not `timeit.default_timer`.
+            ValueError: If `globals` is non-empty.
+        """
+        # Identity check: any timer other than the stock one (for example the
+        # synchronizing accelerator timer) is rejected.
+        if timer is not timeit.default_timer:
+            raise NotImplementedError(
+                "PyTorch was built with accelerators and an accelerator is present; however "
+                "Timer does not yet support accelerator measurements. If your "
+                "code is CPU only, pass `timer=timeit.default_timer` to the "
+                "Timer's constructor to indicate this. (Note that this will "
+                "produce incorrect results if an accelerator is in fact used, as "
+                "Timer will not synchronize the accelerator.)"
+            )
+
+        if globals:
+            raise ValueError("C++ timing does not support globals.")
+
+        self._stmt: str = textwrap.dedent(stmt)
+        self._setup: str = textwrap.dedent(setup)
+        self._global_setup: str = textwrap.dedent(global_setup)
+        # Filled in lazily by `timeit`.
+        self._timeit_module: TimeitModuleType | None = None
+
+    def timeit(self, number: int) -> float:
+        """Build the timing module on first use, then time `number` iterations."""
+        if self._timeit_module is None:
+            self._timeit_module = cpp_jit.compile_timeit_template(
+                stmt=self._stmt,
+                setup=self._setup,
+                global_setup=self._global_setup,
+            )
+
+        return self._timeit_module.timeit(number)
+
+
+class Timer:
+    """Helper class for measuring execution time of PyTorch statements.
+
+    For a full tutorial on how to use this class, see:
+    https://pytorch.org/tutorials/recipes/recipes/benchmark.html
+
+    The PyTorch Timer is based on `timeit.Timer` (and in fact uses
+    `timeit.Timer` internally), but with several key differences:
+
+    1) Runtime aware:
+        Timer will perform warmups (important as some elements of PyTorch are
+        lazily initialized), set threadpool size so that comparisons are
+        apples-to-apples, and synchronize asynchronous accelerator functions when
+        necessary.
+
+    2) Focus on replicates:
+        When measuring code, and particularly complex kernels / models,
+        run-to-run variation is a significant confounding factor. It is
+        expected that all measurements should include replicates to quantify
+        noise and allow median computation, which is more robust than mean.
+        To that effect, this class deviates from the `timeit` API by
+        conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`.
+        (Exact algorithms are discussed in method docstrings.) The `timeit`
+        method is replicated for cases where an adaptive strategy is not
+        desired.
+
+    3) Optional metadata:
+        When defining a Timer, one can optionally specify `label`, `sub_label`,
+        `description`, and `env`. (Defined later) These fields are included in
+        the representation of result object and by the `Compare` class to group
+        and display results for comparison.
+
+    4) Instruction counts
+        In addition to wall times, Timer can run a statement under Callgrind
+        and report instructions executed.
+
+    Directly analogous to `timeit.Timer` constructor arguments:
+
+        `stmt`, `setup`, `timer`, `globals`
+
+    PyTorch Timer specific constructor arguments:
+
+        `label`, `sub_label`, `description`, `env`, `num_threads`
+
+    Args:
+        stmt: Code snippet to be run in a loop and timed.
+
+        setup: Optional setup code. Used to define variables used in `stmt`
+
+        global_setup: (C++ only)
+            Code which is placed at the top level of the file for things like
+            `#include` statements.
+
+        timer:
+            Callable which returns the current time. If PyTorch was built
+            without accelerators or there is no accelerator present, this defaults to
+            `timeit.default_timer`; otherwise it will synchronize accelerators before
+            measuring the time.
+
+        globals:
+            A dict which defines the global variables when `stmt` is being
+            executed. This is the other method for providing variables which
+            `stmt` needs.
+
+        label:
+            String which summarizes `stmt`. For instance, if `stmt` is
+            "torch.nn.functional.relu(torch.add(x, 1, out=out))"
+            one might set label to "ReLU(x + 1)" to improve readability.
+
+        sub_label:
+            Provide supplemental information to disambiguate measurements
+            with identical stmt or label. For instance, in our example
+            above sub_label might be "float" or "int", so that it is easy
+            to differentiate:
+            "ReLU(x + 1): (float)"
+
+            "ReLU(x + 1): (int)"
+            when printing Measurements or summarizing using `Compare`.
+
+        description:
+            String to distinguish measurements with identical label and
+            sub_label. The principal use of `description` is to signal to
+            `Compare` the columns of data. For instance one might set it
+            based on the input size  to create a table of the form: ::
+
+                                        | n=1 | n=4 | ...
+                                        ------------- ...
+                ReLU(x + 1): (float)    | ... | ... | ...
+                ReLU(x + 1): (int)      | ... | ... | ...
+
+
+            using `Compare`. It is also included when printing a Measurement.
+
+        env:
+            This tag indicates that otherwise identical tasks were run in
+            different environments, and are therefore not equivalent, for
+            instance when A/B testing a change to a kernel. `Compare` will
+            treat Measurements with different `env` specification as distinct
+            when merging replicate runs.
+
+        num_threads:
+            The size of the PyTorch threadpool when executing `stmt`. Single
+            threaded performance is important as both a key inference workload
+            and a good indicator of intrinsic algorithmic efficiency, so the
+            default is set to one. This is in contrast to the default PyTorch
+            threadpool size which tries to utilize all cores.
+    """
+
+    _timer_cls: type[TimerClass] = timeit.Timer
+
+    def __init__(
+        self,
+        stmt: str = "pass",
+        setup: str = "pass",
+        global_setup: str = "",
+        timer: Callable[[], float] = timer,
+        globals: dict[str, Any] | None = None,
+        label: str | None = None,
+        sub_label: str | None = None,
+        description: str | None = None,
+        env: str | None = None,
+        num_threads: int = 1,
+        language: Language | str = Language.PYTHON,
+    ) -> None:
+        """Validate the arguments, select the timer backend and record the task spec.
+
+        The arguments are described in the class docstring. `language` accepts
+        a `Language` member or one of the strings "py", "python", "cpp", "c++".
+
+        Raises:
+            ValueError: If `stmt` is not a `str`, if `global_setup` is given
+                for a Python timer, or if `language` is not recognized.
+            AssertionError: If a C++ timer is requested but `_timer_cls` is no
+                longer `timeit.Timer`.
+        """
+        if not isinstance(stmt, str):
+            raise ValueError("Currently only a `str` stmt is supported.")
+
+        # We copy `globals` to prevent mutations from leaking.
+        # (For instance, `eval` adds the `__builtins__` key)
+        self._globals = dict(globals or {})
+
+        # Extra keyword arguments that only the C++ backend accepts.
+        timer_kwargs = {}
+        if language in (Language.PYTHON, "py", "python"):
+            # Include `torch` if not specified as a convenience feature.
+            self._globals.setdefault("torch", torch)
+            self._language: Language = Language.PYTHON
+            if global_setup:
+                raise ValueError(
+                    f"global_setup is C++ only, got `{global_setup}`. Most "
+                    "likely this code can simply be moved to `setup`."
+                )
+
+        elif language in (Language.CPP, "cpp", "c++"):
+            if self._timer_cls is not timeit.Timer:
+                raise AssertionError("_timer_cls has already been swapped.")
+            # Instance-level override; the class attribute is left untouched.
+            self._timer_cls = CPPTimer
+            # The Python default "pass" is replaced by an empty setup for C++.
+            setup = ("" if setup == "pass" else setup)
+            self._language = Language.CPP
+            timer_kwargs["global_setup"] = global_setup
+
+        else:
+            raise ValueError(f"Invalid language `{language}`.")
+
+        # Convenience adjustment so that multi-line code snippets defined in
+        # functions do not IndentationError (Python) or look odd (C++). The
+        # leading newline removal is for the initial newline that appears when
+        # defining block strings. For instance:
+        #   textwrap.dedent("""
+        #     print("This is a stmt")
+        #   """)
+        # produces '\nprint("This is a stmt")\n'.
+        #
+        # Stripping this down to 'print("This is a stmt")' doesn't change
+        # what gets executed, but it makes __repr__'s nicer.
+        stmt = textwrap.dedent(stmt)
+        stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip()
+        setup = textwrap.dedent(setup)
+        setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()
+
+        # pyrefly: ignore [bad-instantiation]
+        self._timer = self._timer_cls(
+            stmt=stmt,
+            setup=setup,
+            timer=timer,
+            globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
+            **timer_kwargs,
+        )
+        # The task spec keeps the normalized snippets together with the
+        # descriptive metadata; it is attached to every Measurement produced.
+        self._task_spec = common.TaskSpec(
+            stmt=stmt,
+            setup=setup,
+            global_setup=global_setup,
+            label=label,
+            sub_label=sub_label,
+            description=description,
+            env=env,
+            num_threads=num_threads,
+        )
+
+    def _timeit(self, number: int) -> float:
+        """Time `number` executions with the underlying timer, clamped to at least 1e-9."""
+        # Even calling a timer in C++ takes ~50 ns, so no real operation should
+        # take less than 1 ns. (And this prevents divide by zero errors.)
+        return max(self._timer.timeit(number), 1e-9)
+
+    def timeit(self, number: int = 1000000) -> common.Measurement:
+        """Mirrors the semantics of timeit.Timer.timeit().
+
+        Execute the main statement (`stmt`) `number` times.
+        https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit
+
+        Args:
+            number: How many times `stmt` is executed in the timed run.
+
+        Returns:
+            A `Measurement` holding the single timed run.
+        """
+        with common.set_torch_threads(self._task_spec.num_threads):
+            # Warmup: roughly 1% of the requested iterations, but at least two.
+            self._timeit(number=max(int(number // 100), 2))
+
+            return common.Measurement(
+                number_per_run=number,
+                raw_times=[self._timeit(number=number)],
+                task_spec=self._task_spec
+            )
+
+    def repeat(self, repeat: int = -1, number: int = -1) -> None:
+        """Not supported; always raises NotImplementedError pointing at `blocked_autorange`."""
+        raise NotImplementedError("See `Timer.blocked_autorange.`")
+
+    def autorange(self, callback: Callable[[int, float], NoReturn] | None = None) -> None:
+        """Not supported; always raises NotImplementedError pointing at `blocked_autorange`."""
+        raise NotImplementedError("See `Timer.blocked_autorange.`")
+
+    def _threaded_measurement_loop(
+        self,
+        number: int,
+        time_hook: Callable[[], float],
+        stop_hook: Callable[[list[float]], bool],
+        min_run_time: float,
+        max_run_time: float | None = None,
+        callback: Callable[[int, float], NoReturn] | None = None
+    ) -> list[float]:
+        """Collect timings from `time_hook` until the stopping conditions are met.
+
+        The loop keeps going until the accumulated time has reached
+        `min_run_time` AND the most recent `stop_hook` call returned a truthy
+        value. It exits early once the accumulated time exceeds a truthy
+        `max_run_time`.
+
+        Args:
+            number: Only forwarded to `callback`.
+            time_hook: Called once per iteration; its return value is recorded
+                and added to the running total.
+            stop_hook: Called with all recorded times after every iteration.
+            min_run_time: Minimum accumulated time before the loop may stop.
+            max_run_time: Optional cap on accumulated time; a falsy value
+                (`None` or 0) disables the cap.
+            callback: Optional hook invoked as `callback(number, time_spent)`
+                after every iteration.
+
+        Returns:
+            The recorded times, in the order they were measured.
+        """
+        total_time = 0.0
+        can_stop = False
+        times: list[float] = []
+        with common.set_torch_threads(self._task_spec.num_threads):
+            while (total_time < min_run_time) or (not can_stop):
+                time_spent = time_hook()
+                times.append(time_spent)
+                total_time += time_spent
+                if callback:
+                    callback(number, time_spent)
+                can_stop = stop_hook(times)
+                # Hard cap, checked after the sample has been recorded.
+                if max_run_time and total_time > max_run_time:
+                    break
+        return times
+
+    def _estimate_block_size(self, min_run_time: float) -> int:
+        """Choose how many `stmt` executions go into one timed block.
+
+        Starting from 1, the block size is multiplied by 10 until one of the
+        following holds for a trial run:
+
+        * the measured overhead is at most 1e-4 of the block time and the
+          block takes at least `min_run_time / 1000`;
+        * the block takes longer than `min_run_time`;
+        * the next size would exceed 2147483647.
+
+        Returns:
+            The selected block size.
+        """
+        with common.set_torch_threads(self._task_spec.num_threads):
+            # Estimate the block size needed for measurement to be negligible
+            # compared to the inner loop. This also serves as a warmup.
+            # Overhead is the median of five zero-iteration timings.
+            overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item()
+            number = 1
+            while True:
+                time_taken = self._timeit(number)
+                relative_overhead = overhead / time_taken
+                if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000:
+                    break
+                if time_taken > min_run_time:
+                    break
+                # Avoid overflow in C++ pybind11 interface
+                if number * 10 > 2147483647:
+                    break
+                number *= 10
+        return number
+
+    def blocked_autorange(
+        self,
+        callback: Callable[[int, float], NoReturn] | None = None,
+        min_run_time: float = 0.2,
+    ) -> common.Measurement:
+        """Measure many replicates while keeping timer overhead to a minimum.
+
+        At a high level, blocked_autorange executes the following pseudo-code::
+
+            `setup`
+
+            total_time = 0
+            while total_time < min_run_time
+                start = timer()
+                for _ in range(block_size):
+                    `stmt`
+                total_time += (timer() - start)
+
+        Note the variable `block_size` in the inner loop. The choice of block
+        size is important to measurement quality, and must balance two
+        competing objectives:
+
+            1) A small block size results in more replicates and generally
+               better statistics.
+
+            2) A large block size better amortizes the cost of `timer`
+               invocation, and results in a less biased measurement. This is
+               important because accelerator synchronization time is non-trivial
+               (order single to low double digit microseconds) and would
+               otherwise bias the measurement.
+
+        blocked_autorange sets block_size by running a warmup period,
+        increasing block size until timer overhead is at most 0.01% (a
+        relative overhead of 1e-4) of the overall computation. This value is
+        then used for the main measurement loop.
+
+        Args:
+            callback: Optional hook invoked as `callback(block_size, time)`
+                after every timed block.
+
+            min_run_time: Minimum total time to spend in the measurement loop.
+
+        Returns:
+            A `Measurement` object that contains measured runtimes and
+            repetition counts, and can be used to compute statistics.
+            (mean, median, etc.)
+        """
+        number = self._estimate_block_size(min_run_time)
+
+        def time_hook() -> float:
+            return self._timeit(number)
+
+        # No statistical stopping rule here: the loop ends as soon as
+        # `min_run_time` has been accumulated.
+        def stop_hook(times: list[float]) -> bool:
+            return True
+
+        times = self._threaded_measurement_loop(
+            number, time_hook, stop_hook,
+            min_run_time=min_run_time,
+            callback=callback)
+
+        return common.Measurement(
+            number_per_run=number,
+            raw_times=times,
+            task_spec=self._task_spec
+        )
+
+    def adaptive_autorange(
+            self,
+            threshold: float = 0.1,
+            *,
+            min_run_time: float = 0.01,
+            max_run_time: float = 10.0,
+            callback: Callable[[int, float], NoReturn] | None = None,
+    ) -> common.Measurement:
+        """Similar to `blocked_autorange` but also checks for variablility in measurements
+        and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached.
+
+
+        At a high level, adaptive_autorange executes the following pseudo-code::
+
+            `setup`
+
+            times = []
+            while times.sum < max_run_time
+                start = timer()
+                for _ in range(block_size):
+                    `stmt`
+                times.append(timer() - start)
+
+                enough_data = len(times)>3 and times.sum > min_run_time
+                small_iqr=times.iqr/times.mean float:
+            return self._timeit(number)
+
+        def stop_hook(times: list[float]) -> bool:
+            if len(times) > 3:
+                return common.Measurement(
+                    number_per_run=number,
+                    raw_times=times,
+                    task_spec=self._task_spec
+                ).meets_confidence(threshold=threshold)
+            return False
+        times = self._threaded_measurement_loop(
+            number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback)
+
+        return common.Measurement(
+            number_per_run=number,
+            raw_times=times,
+            task_spec=self._task_spec
+        )
+
+    # Typing overload: with `repeats=None` a single `CallgrindStats` is returned.
+    @overload
+    def collect_callgrind(
+        self,
+        number: int,
+        *,
+        repeats: None,
+        collect_baseline: bool,
+        retain_out_file: bool,
+    ) -> valgrind_timer_interface.CallgrindStats:
+        ...
+
+    # Typing overload: with an integer `repeats` a tuple of `CallgrindStats` is returned.
+    @overload
+    def collect_callgrind(
+        self,
+        number: int,
+        *,
+        repeats: int,
+        collect_baseline: bool,
+        retain_out_file: bool,
+    ) -> tuple[valgrind_timer_interface.CallgrindStats, ...]:
+        ...
+
+    def collect_callgrind(
+        self,
+        number: int = 100,
+        *,
+        repeats: int | None = None,
+        collect_baseline: bool = True,
+        retain_out_file: bool = False,
+    ) -> Any:
+        """Collect instruction counts using Callgrind.
+
+        Unlike wall times, instruction counts are deterministic
+        (modulo non-determinism in the program itself and small amounts of
+        jitter from the Python interpreter.) This makes them ideal for detailed
+        performance analysis. This method runs `stmt` in a separate process
+        so that Valgrind can instrument the program. Performance is severely
+        degraded due to the instrumentation, however this is ameliorated by
+        the fact that a small number of iterations is generally sufficient to
+        obtain good measurements.
+
+        In order to use this method `valgrind`, `callgrind_control`, and
+        `callgrind_annotate` must be installed.
+
+        Because there is a process boundary between the caller (this process)
+        and the `stmt` execution, `globals` cannot contain arbitrary in-memory
+        data structures. (Unlike timing methods) Instead, globals are
+        restricted to builtins, `nn.Module`s, and TorchScripted functions/modules
+        to reduce the surprise factor from serialization and subsequent
+        deserialization. The `GlobalsBridge` class provides more detail on this
+        subject. Take particular care with nn.Modules: they rely on pickle and
+        you may need to add an import to `setup` for them to transfer properly.
+
+        By default, a profile for an empty statement will be collected and
+        cached to indicate how many instructions are from the Python loop which
+        drives `stmt`.
+
+        Args:
+            number: Forwarded to the Callgrind wrapper as `number`.
+
+            repeats: Number of repeated collections. `None` collects once and
+                returns the single result; an integer must be >= 1 and
+                returns the wrapper's result unchanged.
+
+            collect_baseline: Request a baseline profile; only honored for
+                Python timers.
+
+            retain_out_file: Forwarded to the Callgrind wrapper unchanged.
+
+        Returns:
+            A `CallgrindStats` object which provides instruction counts and
+            some basic facilities for analyzing and manipulating results.
+
+        Raises:
+            ValueError: If the stored `stmt` is not a string, or if `repeats`
+                is given and smaller than 1.
+            AssertionError: If a non-Python timer carries globals.
+        """
+        if not isinstance(self._task_spec.stmt, str):
+            raise ValueError("`collect_callgrind` currently only supports string `stmt`")
+
+        if repeats is not None and repeats < 1:
+            raise ValueError("If specified, `repeats` must be >= 1")
+
+        # Check that the statement is valid. It doesn't guarantee success, but it's much
+        # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in
+        # the parent process rather than the valgrind subprocess.
+        self._timeit(1)
+        is_python = (self._language == Language.PYTHON)
+        if not is_python and self._globals:
+            raise AssertionError("_timer globals are only supported for Python timers")
+        result = valgrind_timer_interface.wrapper_singleton().collect_callgrind(
+            task_spec=self._task_spec,
+            globals=self._globals,
+            number=number,
+            repeats=repeats or 1,
+            collect_baseline=collect_baseline and is_python,
+            is_python=is_python,
+            retain_out_file=retain_out_file,
+        )
+
+        # Unwrap the single result when the caller did not ask for repeats.
+        return (result[0] if repeats is None else result)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..536efea67454e0ee8c4d8f101777b7530cb99235
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d009b1896b83727c0ca65cb8b202ea661fbb5e1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
new file mode 100644
index 0000000000000000000000000000000000000000..f078cc82b95daf94d2bea51f1e1b1a8c12daea23
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
@@ -0,0 +1,129 @@
+
+/*
+   ----------------------------------------------------------------
+
+   Notice that the following BSD-style license applies to this one
+   file (callgrind.h) only.  The rest of Valgrind is licensed under the
+   terms of the GNU General Public License, version 2, unless
+   otherwise indicated.  See the COPYING file in the source
+   distribution for details.
+
+   ----------------------------------------------------------------
+
+   This file is part of callgrind, a valgrind tool for cache simulation
+   and call tree tracing.
+
+   Copyright (C) 2003-2017 Josef Weidendorfer.  All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   1. Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+   2. The origin of this software must not be misrepresented; you must
+      not claim that you wrote the original software.  If you use this
+      software in a product, an acknowledgment in the product
+      documentation would be appreciated but is not required.
+
+   3. Altered source versions must be plainly marked as such, and must
+      not be misrepresented as being the original software.
+
+   4. The name of the author may not be used to endorse or promote
+      products derived from this software without specific prior written
+      permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
+   OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+   WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+   ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+   DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+   GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+   WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+   ----------------------------------------------------------------
+
+   Notice that the above BSD-style license applies to this one file
+   (callgrind.h) only.  The entire rest of Valgrind is licensed under
+   the terms of the GNU General Public License, version 2.  See the
+   COPYING file in the source distribution for details.
+
+   ----------------------------------------------------------------
+*/
+
+#ifndef __CALLGRIND_H
+#define __CALLGRIND_H
+
+#include "valgrind.h"
+
+/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !!
+   This enum comprises an ABI exported by Valgrind to programs
+   which use client requests.  DO NOT CHANGE THE ORDER OF THESE
+   ENTRIES, NOR DELETE ANY -- add new ones at the end.
+
+   The identification ('C','T') for Callgrind has historical
+   reasons: it was called "Calltree" before. Besides, ('C','G') would
+   clash with cachegrind.
+ */
+
+typedef
+   enum {
+      VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'),
+      VG_USERREQ__ZERO_STATS,
+      VG_USERREQ__TOGGLE_COLLECT,
+      VG_USERREQ__DUMP_STATS_AT,
+      VG_USERREQ__START_INSTRUMENTATION,
+      VG_USERREQ__STOP_INSTRUMENTATION
+   } Vg_CallgrindClientRequest;
+
+/* Dump current state of cost centers, and zero them afterwards */
+#define CALLGRIND_DUMP_STATS                                    \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS,       \
+                                  0, 0, 0, 0, 0)
+
+/* Dump current state of cost centers, and zero them afterwards.
+   The argument is appended to a string stating the reason which triggered
+   the dump. This string is written as a description field into the
+   profile data dump. */
+#define CALLGRIND_DUMP_STATS_AT(pos_str)                        \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS_AT,    \
+                                  pos_str, 0, 0, 0, 0)
+
+/* Zero cost centers */
+#define CALLGRIND_ZERO_STATS                                    \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS,       \
+                                  0, 0, 0, 0, 0)
+
+/* Toggles collection state.
+   The collection state specifies whether the happening of events
+   should be noted or if they are to be ignored. Events are noted
+   by increment of counters in a cost center */
+#define CALLGRIND_TOGGLE_COLLECT                                \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__TOGGLE_COLLECT,   \
+                                  0, 0, 0, 0, 0)
+
+/* Start full callgrind instrumentation if not already switched on.
+   When cache simulation is done, it will flush the simulated cache;
+   this will lead to an artificial cache warmup phase afterwards with
+   cache misses which would not have happened in reality. */
+#define CALLGRIND_START_INSTRUMENTATION                              \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__START_INSTRUMENTATION, \
+                                  0, 0, 0, 0, 0)
+
+/* Stop full callgrind instrumentation if not already switched off.
+   This flushes Valgrinds translation cache, and does no additional
+   instrumentation afterwards, which effectivly will run at the same
+   speed as the "none" tool (ie. at minimal slowdown).
+   Use this to bypass Callgrind aggregation for uninteresting code parts.
+   To start Callgrind in this mode to ignore the setup phase, use
+   the option "--instr-atstart=no". */
+#define CALLGRIND_STOP_INSTRUMENTATION                               \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STOP_INSTRUMENTATION,  \
+                                  0, 0, 0, 0, 0)
+
+#endif /* __CALLGRIND_H */
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cd41f0de092f0b1488c8945edf2af80c6f9b596c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
@@ -0,0 +1,35 @@
+/* Used to collect profiles of old versions of PyTorch. */
+#include 
+#include 
+
+// Returns false when the build defines NVALGRIND, true otherwise.
+bool _valgrind_supported_platform() {
+#if defined(NVALGRIND)
+  return false;
+#else
+  return true;
+#endif
+}
+
+// Issues CALLGRIND_TOGGLE_COLLECT; fails via TORCH_CHECK when the build
+// defines NVALGRIND.
+void _valgrind_toggle() {
+#if defined(NVALGRIND)
+  TORCH_CHECK(false, "Valgrind is not supported.");
+#else
+  CALLGRIND_TOGGLE_COLLECT;
+#endif
+}
+
+// Issues CALLGRIND_TOGGLE_COLLECT followed by CALLGRIND_DUMP_STATS; fails via
+// TORCH_CHECK when the build defines NVALGRIND.
+void _valgrind_toggle_and_dump_stats() {
+#if defined(NVALGRIND)
+  TORCH_CHECK(false, "Valgrind is not supported.");
+#else
+  // NB: See note in Module.cpp
+  CALLGRIND_TOGGLE_COLLECT;
+  CALLGRIND_DUMP_STATS;
+#endif
+}
+
+// Python bindings for the `callgrind_bindings` extension module; each Python
+// name maps to the C++ function of the same name defined in this file.
+PYBIND11_MODULE(callgrind_bindings, m) {
+  m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
+  m.def("_valgrind_toggle", &_valgrind_toggle);
+  // Must reference `_valgrind_toggle_and_dump_stats`: no function named
+  // `_valgrind_dump_stats` is defined in this translation unit.
+  m.def("_valgrind_toggle_and_dump_stats", &_valgrind_toggle_and_dump_stats);
+}
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..587685c7df7445b299c35462307f47cf6012a00d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
@@ -0,0 +1,68 @@
+/* C++ template for Timer.collect_callgrind
+
+This template will be consumed by `cpp_jit.py`, and will replace:
+    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
+    `SETUP_TEMPLATE_LOCATION`
+      and
+    `STMT_TEMPLATE_LOCATION`
+sections with user provided statements.
+*/
+
+#include 
+#include 
+#include 
+
+#include 
+
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
+
+// Fail the build outright when NVALGRIND is defined: this program relies on
+// the CALLGRIND_* macros used in main().
+#if defined(NVALGRIND)
+static_assert(false);
+#endif
+
+// Entry point of the generated benchmark executable.
+//
+// Expects exactly eight arguments after the program name, in this fixed order:
+//   --number N --number-warmup W --repeats R --number-threads T
+// (the underscore spellings of the warmup and threads flags are accepted too).
+// The statement is run W times as warmup, followed by R measured rounds of N
+// iterations each. Callgrind collection is toggled on before every round and
+// toggled off again, with a stats dump, after it.
+int main(int argc, char* argv[]) {
+  // This file should only be called inside of `Timer`, so we can adopt a
+  // very simple and rigid argument parsing scheme.
+  TORCH_CHECK(argc == 9);
+  TORCH_CHECK(std::string(argv[1]) == "--number");
+  auto number = std::stoi(argv[2]);
+
+  // Both the dashed and the underscored flag spelling are accepted.
+  TORCH_CHECK(
+      std::string(argv[3]) == "--number-warmup" ||
+      std::string(argv[3]) == "--number_warmup");
+  auto number_warmup = std::stoi(argv[4]);
+
+  TORCH_CHECK(std::string(argv[5]) == "--repeats");
+  auto repeats = std::stoi(argv[6]);
+
+  // Both the dashed and the underscored flag spelling are accepted.
+  TORCH_CHECK(
+      std::string(argv[7]) == "--number-threads" ||
+      std::string(argv[7]) == "--number_threads");
+  auto number_threads = std::stoi(argv[8]);
+  torch::set_num_threads(number_threads);
+
+  // Setup
+  // SETUP_TEMPLATE_LOCATION
+
+  // Warmup
+  for (const auto i : c10::irange(number_warmup)) {
+    // The cast presumably silences unused-variable warnings for the index.
+    (void)i;
+    // STMT_TEMPLATE_LOCATION
+  }
+
+  // Main loop
+  for (const auto repeat : c10::irange(repeats)) {
+    (void)repeat;
+    // Start collecting for this round.
+    CALLGRIND_TOGGLE_COLLECT;
+
+    for (const auto i : c10::irange(number)) {
+      (void)i;
+      // STMT_TEMPLATE_LOCATION
+    }
+
+    // NB: See note in Module.cpp
+    // Stop collecting, then write out one profile for this round.
+    CALLGRIND_TOGGLE_COLLECT;
+    CALLGRIND_DUMP_STATS;
+  }
+}
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ecea8bbb5598db967e8213b5bbd9c0fd8562f3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -0,0 +1,919 @@
+"""Intermediate layer between `Timer` and `valgrind`."""
+import collections
+import enum
+import dataclasses
+import itertools as it
+import os
+import pickle
+import re
+import shutil
+import subprocess
+import sys
+import textwrap
+from typing import (
+    cast, Any, NamedTuple,
+    Union, TYPE_CHECKING)
+from collections.abc import Callable
+from collections.abc import Iterator
+
+import torch
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType
+import operator
+
+
+# Public names of this module (what `from ... import *` exposes).
+__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
+
+
+# `subprocess.CompletedProcess` is only subscripted for static type checkers;
+# at runtime the plain, unparameterized class is used as the alias.
+if TYPE_CHECKING:
+    CompletedProcessType = subprocess.CompletedProcess[str]
+else:
+    CompletedProcessType = subprocess.CompletedProcess
+
+
+class FunctionCount(NamedTuple):
+    # TODO(#105471): Rename the count field
+    count: int  # type: ignore[assignment]
+    function: str
+
+
+@dataclasses.dataclass(repr=False, eq=False, frozen=True)
+class FunctionCounts:
+    """Container for manipulating Callgrind results.
+
+    It supports:
+        1) Addition and subtraction to combine or diff results.
+        2) Tuple-like indexing.
+        3) A `denoise` function which strips CPython calls which are known to
+           be non-deterministic and quite noisy.
+        4) Two higher order methods (`filter` and `transform`) for custom
+           manipulation.
+    """
+    _data: tuple[FunctionCount, ...]
+    inclusive: bool
+    truncate_rows: bool = True
+
+    # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
+    # the print settings. This is simply to allow hermetic unit tests.
+    _linewidth: int | None = None
+
+    def __iter__(self) -> Iterator[FunctionCount]:
+        yield from self._data
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
+        data: FunctionCount | tuple[FunctionCount, ...] = self._data[item]
+        return (
+            FunctionCounts(cast(tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False)
+            if isinstance(data, tuple) else data
+        )
+
+    def __repr__(self) -> str:
+        count_len = 0
+        for c, _ in self:
+            # Account for sign in string length.
+            count_len = max(count_len, len(str(c)) + int(c < 0))
+
+        lines = []
+        linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth
+        fn_str_len = max(linewidth - count_len - 4, 40)
+        for c, fn in self:
+            if len(fn) > fn_str_len:
+                left_len = int((fn_str_len - 5) // 2)
+                fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
+            lines.append(f"  {c:>{count_len}}  {fn}")
+
+        if self.truncate_rows and len(lines) > 18:
+            lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]
+
+        if not self.inclusive:
+            lines.extend(["", f"Total: {self.sum()}"])
+
+        return "\n".join([super().__repr__()] + lines)
+
+    def __add__(
+        self,
+        other: "FunctionCounts",
+    ) -> "FunctionCounts":
+        return self._merge(other, lambda c: c)
+
+    def __sub__(
+        self,
+        other: "FunctionCounts",
+    ) -> "FunctionCounts":
+        return self._merge(other, operator.neg)
+
+    def __mul__(self, other: int | float) -> "FunctionCounts":
+        return self._from_dict({
+            fn: int(c * other) for c, fn in self._data
+        }, self.inclusive)
+
+    def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
+        """Apply `map_fn` to all of the function names.
+
+        This can be used to regularize function names (e.g. stripping irrelevant
+        parts of the file path), coalesce entries by mapping multiple functions
+        to the same name (in which case the counts are added together), etc.
+        """
+        counts: collections.defaultdict[str, int] = collections.defaultdict(int)
+        for c, fn in self._data:
+            counts[map_fn(fn)] += c
+
+        return self._from_dict(counts, self.inclusive)
+
+    def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
+        """Keep only the elements where `filter_fn` applied to function name returns True."""
+        return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)
+
+    def sum(self) -> int:
+        return sum(c for c, _ in self)
+
+    def denoise(self) -> "FunctionCounts":
+        """Remove known noisy instructions.
+
+        Several instructions in the CPython interpreter are rather noisy. These
+        instructions involve unicode to dictionary lookups which Python uses to
+        map variable names. FunctionCounts is generally a content agnostic
+        container, however this is sufficiently important for obtaining
+        reliable results to warrant an exception."""
+        return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)
+
+    def _merge(
+        self,
+        second: "FunctionCounts",
+        merge_fn: Callable[[int], int]
+    ) -> "FunctionCounts":
+        if self.inclusive != second.inclusive:
+            raise AssertionError("Cannot merge inclusive and exclusive counts.")
+        counts: collections.defaultdict[str, int] = collections.defaultdict(int)
+        for c, fn in self:
+            counts[fn] += c
+
+        for c, fn in second:
+            counts[fn] += merge_fn(c)
+
+        return self._from_dict(counts, self.inclusive)
+
+    @staticmethod
+    def _from_dict(counts: dict[str, int], inclusive: bool) -> "FunctionCounts":
+        flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
+        return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)
+
+
+@dataclasses.dataclass(repr=False, eq=False, frozen=True)
+class CallgrindStats:
+    """Top level container for Callgrind results collected by Timer.
+
+    Manipulation is generally done using the FunctionCounts class, which is
+    obtained by calling `CallgrindStats.stats(...)`. Several convenience
+    methods are provided as well; the most significant is
+    `CallgrindStats.as_standardized()`.
+    """
+    task_spec: common.TaskSpec
+    number_per_run: int
+    built_with_debug_symbols: bool
+    baseline_inclusive_stats: FunctionCounts
+    baseline_exclusive_stats: FunctionCounts
+    stmt_inclusive_stats: FunctionCounts
+    stmt_exclusive_stats: FunctionCounts
+    stmt_callgrind_out: str | None
+
+    def __repr__(self) -> str:
+        base_stats = self.baseline_exclusive_stats
+        output = f"""
+{super().__repr__()}
+{self.task_spec.summarize()}
+  {'':>25}All{'':>10}Noisy symbols removed
+    Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
+    Baseline:     {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
+{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
+""".strip()
+        if not self.built_with_debug_symbols:
+            output += textwrap.dedent("""
+            Warning: PyTorch was not built with debug symbols.
+                     Source information may be limited. Rebuild with
+                     REL_WITH_DEB_INFO=1 for more detailed results.""")
+        return output
+
+    def stats(self, inclusive: bool = False) -> FunctionCounts:
+        """Returns detailed function counts.
+
+        Conceptually, the FunctionCounts returned can be thought of as a tuple
+        of (count, path_and_function_name) tuples.
+
+        `inclusive` matches the semantics of callgrind. If True, the counts
+        include instructions executed by children. `inclusive=True` is useful
+        for identifying hot spots in code; `inclusive=False` is useful for
+        reducing noise when diffing counts from two different runs. (See
+        CallgrindStats.delta(...) for more details)
+        """
+        return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats
+
+    def counts(self, *, denoise: bool = False) -> int:
+        """Returns the total number of instructions executed.
+
+        See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
+        """
+        stats = self.stmt_exclusive_stats
+        return (stats.denoise() if denoise else stats).sum()
+
+    # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
+    def delta(
+        self,
+        other: "CallgrindStats",
+        inclusive: bool = False,
+    ) -> FunctionCounts:
+        """Diff two sets of counts.
+
+        One common reason to collect instruction counts is to determine the
+        the effect that a particular change will have on the number of instructions
+        needed to perform some unit of work. If a change increases that number, the
+        next logical question is "why". This generally involves looking at what part
+        if the code increased in instruction count. This function automates that
+        process so that one can easily diff counts on both an inclusive and
+        exclusive basis.
+        """
+        return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)
+
+    def as_standardized(self) -> "CallgrindStats":
+        """Strip library names and some prefixes from function strings.
+
+        When comparing two different sets of instruction counts, on stumbling
+        block can be path prefixes. Callgrind includes the full filepath
+        when reporting a function (as it should). However, this can cause
+        issues when diffing profiles. If a key component such as Python
+        or PyTorch was built in separate locations in the two profiles, which
+        can result in something resembling::
+
+            23234231 /tmp/first_build_dir/thing.c:foo(...)
+             9823794 /tmp/first_build_dir/thing.c:bar(...)
+              ...
+               53453 .../aten/src/Aten/...:function_that_actually_changed(...)
+              ...
+             -9823794 /tmp/second_build_dir/thing.c:bar(...)
+            -23234231 /tmp/second_build_dir/thing.c:foo(...)
+
+        Stripping prefixes can ameliorate this issue by regularizing the
+        strings and causing better cancellation of equivalent call sites
+        when diffing.
+        """
+        def strip(stats: FunctionCounts) -> FunctionCounts:
+            transforms = (
+                # PyTorch may have been built in different locations.
+                (r"^.+build/\.\./", "build/../"),
+                (r"^.+/" + re.escape("build/aten/"), "build/aten/"),
+
+                # "Python" and "Objects" come from CPython.
+                (r"^.+/" + re.escape("Python/"), "Python/"),
+                (r"^.+/" + re.escape("Objects/"), "Objects/"),
+
+                # Strip library name. e.g. `libtorch.so`
+                (r"\s\[.+\]$", ""),
+            )
+
+            for before, after in transforms:
+                stats = stats.transform(lambda fn: re.sub(before, after, fn))
+
+            return stats
+
+        return CallgrindStats(
+            task_spec=self.task_spec,
+            number_per_run=self.number_per_run,
+            built_with_debug_symbols=self.built_with_debug_symbols,
+            baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
+            baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
+            stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
+            stmt_exclusive_stats=strip(self.stmt_exclusive_stats),
+
+            # `as_standardized` will change symbol names, so the contents will
+            # no longer map directly to `callgrind.out`
+            stmt_callgrind_out=None,
+        )
+
+
+class Serialization(enum.Enum):
+    PICKLE = 0
+    TORCH = 1
+    TORCH_JIT = 2
+
+
+_GLOBALS_ALLOWED_TYPES: dict[Serialization, tuple[Any, ...]] = {
+    Serialization.PICKLE: (str, bytes, bool, int, float, complex),
+    Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
+    Serialization.TORCH: (torch.nn.Module,),
+}
+
+
+class CopyIfCallgrind:
+    """Signal that a global may be replaced with a deserialized copy.
+
+    See `GlobalsBridge` for why this matters.
+    """
+    def __init__(self, value: Any, *, setup: str | None = None) -> None:
+        for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
+            if any(isinstance(value, t) for t in supported_types):
+                self._value: Any = value
+                self._setup: str | None = setup
+                self._serialization: Serialization = method
+                break
+        else:
+            supported_str = "\n".join([
+                getattr(t, "__name__", repr(t))
+                for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())])
+
+            raise ValueError(
+                f"Unsupported type: {type(value)}\n"
+                f"`collect_callgrind` restricts globals to the following types:\n"
+                f"{textwrap.indent(supported_str, '  ')}"
+            )
+
+    @property
+    def value(self) -> Any:
+        return self._value
+
+    @property
+    def setup(self) -> str | None:
+        return self._setup
+
+    @property
+    def serialization(self) -> Serialization:
+        return self._serialization
+
+    @staticmethod
+    def unwrap_all(globals: dict[str, Any]) -> dict[str, Any]:
+        return {
+            k: (v.value if isinstance(v, CopyIfCallgrind) else v)
+            for k, v in globals.items()
+        }
+
+
+class GlobalsBridge:
+    """Handle the transfer of (certain) globals when collecting Callgrind statistics.
+
+    Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
+                  work with `Timer.collect_callgrind`.
+
+    Consider the following code snippet:
+    ```
+        import pickle
+        import timeit
+
+        class Counter:
+            value = 0
+
+            def __call__(self):
+                self.value += 1
+
+        counter = Counter()
+        timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
+        print(counter.value)  # 10
+
+        timeit.Timer(
+            "counter()",
+            globals={"counter": pickle.loads(pickle.dumps(counter))}
+        ).timeit(20)
+        print(counter.value)  # Still 10
+    ```
+
+    In the first case, `stmt` is executed using the objects in `globals`;
+    however, the addition of serialization and deserialization changes the
+    semantics and may meaningfully change behavior.
+
+    This is a practical consideration when collecting Callgrind statistics.
+    Unlike `exec` based execution (which `timeit` uses under the hood) which
+    can share in-memory data structures with the caller, Callgrind collection
+    requires an entirely new process in order to run under Valgrind. This means
+    that any data structures used for statement execution will have to be
+    serialized and deserialized in the subprocess.
+
+    In order to avoid surprising semantics from (user invisible) process
+    boundaries, what can be passed through `globals` is severely restricted
+    for `Timer.collect_callgrind`. It is expected that most setup should be
+    achievable (albeit perhaps less ergonomically) by passing a `setup`
+    string.
+
+    There are, however, exceptions. One such class are TorchScripted functions.
+    Because they require a concrete file with source code it is not possible
+    to define them using a `setup` string. Another group are torch.nn.Modules,
+    whose construction can be complex and prohibitively cumbersome to coerce
+    into a `setup` string. Finally, most builtin types are sufficiently well
+    behaved and sufficiently common to warrant allowing as well. (e.g.
+    `globals={"n": 1}` is very convenient.)
+
+    Fortunately, all have well defined serialization semantics. This class
+    is responsible for enabling the Valgrind subprocess to use elements in
+    `globals` so long as they are an allowed type.
+
+    Caveats:
+        The user is required to acknowledge this serialization by wrapping
+        elements in `globals` with `CopyIfCallgrind`.
+
+        While ScriptFunction and ScriptModule are expected to save and load
+        quite robustly, it is up to the user to ensure that an nn.Module can
+        un-pickle successfully.
+
+        `torch.Tensor` and `np.ndarray` are deliberately excluded. The
+        serialization/deserialization process perturbs the representation of a
+        tensor in ways that could result in incorrect measurements. For example,
+        if a tensor lives in pinned CPU memory, this fact would not be preserved
+        by a dump, and that will in turn change the performance of certain CUDA
+        operations.
+    """
+
+    def __init__(self, globals: dict[str, Any], data_dir: str) -> None:
+        self._globals: dict[str, CopyIfCallgrind] = {}
+        self._data_dir = data_dir
+        if not os.path.exists(data_dir):
+            os.mkdir(data_dir)
+
+        if globals.get("torch", torch) is not torch:
+            raise ValueError("`collect_callgrind` does not support mocking out `torch`.")
+
+        for name, value in globals.items():
+            if name in ("torch", "__builtins__"):
+                # Torch will be imported by the collection script, and
+                # __builtins__ is added by Timer.
+                continue
+
+            if not isinstance(value, CopyIfCallgrind):
+                raise ValueError(
+                    "`collect_callgrind` requires that globals be wrapped in "
+                    "`CopyIfCallgrind` so that serialization is explicit."
+                )
+
+            self._globals[name] = value
+
+    def construct(self) -> str:
+        load_lines = []
+        for name, wrapped_value in self._globals.items():
+            if wrapped_value.setup is not None:
+                # pyrefly: ignore [bad-argument-type]
+                load_lines.append(textwrap.dedent(wrapped_value.setup))
+
+            if wrapped_value.serialization == Serialization.PICKLE:
+                path = os.path.join(self._data_dir, f"{name}.pkl")
+                load_lines.append(
+                    # pyrefly: ignore [bad-argument-type]
+                    f"with open({repr(path)}, 'rb') as f:\n    {name} = pickle.load(f)")
+                with open(path, "wb") as f:
+                    pickle.dump(wrapped_value.value, f)
+
+            elif wrapped_value.serialization == Serialization.TORCH:
+                path = os.path.join(self._data_dir, f"{name}.pt")
+                # TODO: Figure out if we can use torch.serialization.add_safe_globals here
+                # Using weights_only=False after the change in
+                # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
+                # pyrefly: ignore [bad-argument-type]
+                load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
+                torch.save(wrapped_value.value, path)
+
+            elif wrapped_value.serialization == Serialization.TORCH_JIT:
+                path = os.path.join(self._data_dir, f"{name}.pt")
+                # pyrefly: ignore [bad-argument-type]
+                load_lines.append(f"{name} = torch.jit.load({repr(path)})")
+                with open(path, "wb") as f:
+                    torch.jit.save(wrapped_value.value, f)  # type: ignore[no-untyped-call]
+
+            else:
+                raise NotImplementedError(
+                    f"Unknown serialization method: {wrapped_value.serialization}")
+
+        return "\n".join(load_lines)
+
+
+class _ValgrindWrapper:
+    def __init__(self) -> None:
+        self._bindings_module: CallgrindModuleType | None = None
+        valgrind_symbols = (
+            "_valgrind_supported_platform",
+            "_valgrind_toggle",
+            "_valgrind_toggle_and_dump_stats",
+        )
+        if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
+            self._supported_platform: bool = torch._C._valgrind_supported_platform()
+
+        else:
+            print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
+            self._bindings_module = cpp_jit.get_compat_bindings()
+            if not all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols):
+                raise AssertionError("JIT-compiled callgrind bindings are missing required symbols")
+            self._supported_platform = self._bindings_module._valgrind_supported_platform()
+
+        self._commands_available: dict[str, bool] = {}
+        if self._supported_platform:
+            # Only bother checking on supported platforms.
+            for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
+                self._commands_available[cmd] = not subprocess.run(
+                    ["which", cmd],
+                    capture_output=True,
+                    check=False,
+                ).returncode
+
+        self._build_type: str | None = None
+        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())  # type: ignore[no-untyped-call]
+        if build_search is not None:
+            self._build_type = build_search.groups()[0].split(",")[0]
+
+    def _validate(self) -> None:
+        if not self._supported_platform:
+            raise OSError("Valgrind is not supported on this platform.")
+
+        missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
+        if missing_cmds:
+            raise OSError("Missing: " + ", ".join(missing_cmds))
+
+    def collect_callgrind(
+        self,
+        task_spec: common.TaskSpec,
+        globals: dict[str, Any],
+        *,
+        number: int,
+        repeats: int,
+        collect_baseline: bool,
+        is_python: bool,
+        retain_out_file: bool,
+    ) -> tuple[CallgrindStats, ...]:
+        """Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
+        self._validate()
+        if not is_python and collect_baseline:
+            raise AssertionError("collect_baseline is only supported for Python timers")
+
+        *task_stats, baseline_stats = self._invoke(
+            task_spec=task_spec,
+            globals=globals,
+            number=number,
+            repeats=repeats,
+            collect_baseline=collect_baseline,
+            is_python=is_python,
+            retain_out_file=retain_out_file,
+        )
+        if len(task_stats) != repeats:
+            raise AssertionError("Unexpected number of task stats returned from _invoke")
+
+        return tuple(
+            CallgrindStats(
+                task_spec=task_spec,
+                number_per_run=number,
+                built_with_debug_symbols=self._build_type == "RelWithDebInfo",
+                baseline_inclusive_stats=baseline_stats[0],
+                baseline_exclusive_stats=baseline_stats[1],
+                stmt_inclusive_stats=stmt_inclusive_stats,
+                stmt_exclusive_stats=stmt_exclusive_stats,
+                stmt_callgrind_out=out_contents,
+            )
+            for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
+        )
+
+    def _invoke(
+        self,
+        *,
+        task_spec: common.TaskSpec,
+        globals: dict[str, Any],
+        number: int,
+        repeats: int,
+        collect_baseline: bool,
+        is_python: bool,
+        retain_out_file: bool,
+    ) -> tuple[tuple[FunctionCounts, FunctionCounts, str | None], ...]:
+        """Core invocation method for Callgrind collection.
+
+        Valgrind operates by effectively replacing the CPU with an emulated
+        version which allows it to instrument any code at the cost of severe
+        performance degradation. This has the practical effect that in order
+        to collect Callgrind statistics, a new process has to be created
+        running under `valgrind`. The steps for this process are:
+
+        1) Create a scratch directory.
+        2) Codegen a run script. (_ValgrindWrapper._construct_script)
+            Inside the run script:
+                * Validate that Python and torch match the parent process
+                * Validate that it is indeed running under valgrind
+                * Execute `setup` and warm up `stmt`
+                * Begin collecting stats
+                * Run the `stmt` loop
+                * Stop collecting stats
+        3) Parse the run results.
+        4) Cleanup the scratch directory.
+        """
+        working_dir = common._make_temp_dir(prefix="callgrind")
+        data_dir = os.path.join(working_dir, "data")
+        script_file = os.path.join(working_dir, "timer_callgrind.py")
+        callgrind_out = os.path.join(working_dir, "callgrind.out")
+        error_log = os.path.join(working_dir, "error.txt")
+        stat_log = os.path.join(working_dir, "callgrind_stat.txt")
+        stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")
+
+        def run(args: list[str], **kwargs: Any) -> tuple[CompletedProcessType, str]:
+            # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
+            with open(stdout_stderr_log, "wb") as f_stdout_stderr:
+                invocation = subprocess.run(
+                    args,
+                    stdout=f_stdout_stderr,
+                    stderr=subprocess.STDOUT,
+                    **kwargs,
+                )
+                with open(stdout_stderr_log) as f:
+                    return invocation, f.read()
+
+        try:
+            if is_python:
+                if self._bindings_module is not None:
+                    shutil.copy(
+                        self._bindings_module.__file__,
+                        os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
+                    )
+
+                script_file = os.path.join(working_dir, "timer_callgrind.py")
+                with open(script_file, "w") as f:
+                    f.write(self._construct_script(
+                        task_spec,
+                        globals=GlobalsBridge(globals, data_dir),
+                        number=number,
+                        repeats=repeats,
+                        collect_baseline=collect_baseline,
+                        error_log=error_log,
+                        stat_log=stat_log,
+                        bindings=self._bindings_module))
+
+                run_loop_cmd = ["python", script_file]
+            else:
+                if collect_baseline:
+                    raise AssertionError("collect_baseline must be False for non-Python timers")
+                run_loop_exec = cpp_jit.compile_callgrind_template(
+                    stmt=task_spec.stmt,
+                    setup=task_spec.setup,
+                    global_setup=task_spec.global_setup,
+                )
+                run_loop_cmd = [
+                    run_loop_exec,
+                    "--number", str(number),
+                    "--number-warmup", str(min(number, 10)),
+                    "--repeats", str(repeats),
+                    "--number-threads", str(task_spec.num_threads),
+                ]
+
+            valgrind_invocation, valgrind_invocation_output = run([
+                "valgrind",
+                "--tool=callgrind",
+                f"--callgrind-out-file={callgrind_out}",
+                "--dump-line=yes",
+                "--dump-instr=yes",
+                "--instr-atstart=yes",
+                "--collect-atstart=no",
+            ] + run_loop_cmd)
+
+            if valgrind_invocation.returncode:
+                error_report = ""
+                if os.path.exists(error_log):
+                    with open(error_log) as f:
+                        error_report = f.read()
+                if not error_report:
+                    error_report = "Unknown error.\n" + valgrind_invocation_output
+
+                raise OSError(f"Failed to collect callgrind profile:\n{error_report}")
+
+            def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
+                """Annotate one callgrind profile and parse its per-function counts.
+
+                Runs `callgrind_annotate` on `fpath` and scans its text output in
+                three phases: find the "PROGRAM TOTALS" line, find the
+                "Ir file:function" table header, then read table rows until the
+                first line that is neither a row nor a dashed separator.
+
+                Args:
+                    fpath: Path of the callgrind output file to annotate.
+                    inclusive: Passed through as `--inclusive=yes|no` and recorded
+                        on the returned `FunctionCounts`.
+
+                Returns:
+                    A `FunctionCounts` built from the parsed rows, sorted with
+                    `reverse=True`.
+
+                Raises:
+                    AssertionError: If the scan never reached the function table.
+                """
+                # NOTE(review): `--threshold=100` presumably asks for every function
+                # rather than only the hottest ones -- confirm against the
+                # callgrind_annotate documentation.
+                _annotate_invocation, annotate_invocation_output = run([
+                    "callgrind_annotate",
+                    f"--inclusive={'yes' if inclusive else 'no'}",
+                    "--threshold=100",
+                    "--show-percs=no",
+                    fpath
+                ], check=True)
+
+                # Counts are printed with thousands separators ("1,234,567"); the
+                # commas are stripped before conversion to int below.
+                total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
+                begin_pattern = re.compile(r"Ir\s+file:function")
+                function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")
+
+                class ScanState(enum.Enum):
+                    SCANNING_FOR_TOTAL = 0
+                    SCANNING_FOR_START = 1
+                    PARSING = 2
+
+                scan_state = ScanState.SCANNING_FOR_TOTAL
+                fn_counts = []
+                for l in annotate_invocation_output.splitlines(keepends=False):
+                    if scan_state == ScanState.SCANNING_FOR_TOTAL:
+                        total_match = total_pattern.match(l)
+                        if total_match:
+                            # Remembered so that rows equal to the whole-program
+                            # total can be dropped while parsing the table.
+                            program_totals = int(total_match.groups()[0].replace(",", ""))
+                            scan_state = ScanState.SCANNING_FOR_START
+
+                    elif scan_state == ScanState.SCANNING_FOR_START:
+                        if begin_pattern.match(l):
+                            scan_state = ScanState.PARSING
+
+                    else:
+                        if scan_state != ScanState.PARSING:
+                            raise AssertionError("Failed to enter PARSING state while parsing callgrind_annotate output")
+                        fn_match = function_pattern.match(l)
+                        if fn_match:
+                            ir_str, file_function = fn_match.groups()
+                            ir = int(ir_str.replace(",", ""))
+                            # `program_totals` is always bound here: PARSING is only
+                            # reachable after the totals line has been matched.
+                            if ir == program_totals:  # type: ignore[possibly-undefined]
+                                # Callgrind includes some top level red herring symbols when
+                                # a program dumps multiple profiles.
+                                continue
+                            fn_counts.append(FunctionCount(ir, file_function))
+
+                        elif re.match(r"-+", l):
+                            # Ignore heading separator lines.
+                            continue
+
+                        else:
+                            # First line that is neither a row nor a separator
+                            # ends the table.
+                            break
+
+                # Still in an earlier state means the totals line or the table
+                # header was never seen.
+                if scan_state != ScanState.PARSING:
+                    raise AssertionError(f"Failed to parse {fpath}")
+                return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)
+
+            def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, str | None]:
+                if i == repeats and not collect_baseline:
+                    # Null baseline.
+                    return (
+                        FunctionCounts((), inclusive=True),
+                        FunctionCounts((), inclusive=False),
+                        None,
+                    )
+
+                fpath = f"{callgrind_out}.{i + 1}"  # Callgrind one-indexes files.
+                callgrind_out_contents: str | None = None
+                if retain_out_file:
+                    with open(fpath) as f:
+                        callgrind_out_contents = f.read()
+
+                return (
+                    parse_output(fpath, inclusive=True),
+                    parse_output(fpath, inclusive=False),
+                    callgrind_out_contents
+                )
+
+            return tuple(read_results(i) for i in range(repeats + 1))
+        finally:
+            shutil.rmtree(working_dir)
+
+    @staticmethod
+    def _construct_script(
+        task_spec: common.TaskSpec,
+        globals: GlobalsBridge,
+        *,
+        number: int,
+        repeats: int,
+        collect_baseline: bool,
+        error_log: str,
+        stat_log: str,
+        bindings: CallgrindModuleType | None,
+    ) -> str:
+        def block_stmt(stmt: str, indent: int = 0) -> str:
+            """Partially unroll benchmark loop.
+
+            The naive template looks something like:
+                "for _ in range({number}): {stmt}"
+
+            However a loop in Python is surprisingly expensive, and significantly
+            increases the number of background Python instructions. So instead we
+            partially unroll the loops, with a block size of 100 chosen to keep
+            the instruction overhead from `range` low while also not ballooning
+            the size of the generated file.
+            """
+            block_size = 100
+            loop_count = number // block_size
+            if loop_count == 1:
+                # There is no point in having `for _ in range(1): ...` rather
+                # than just `...`, and this lets us save shave a few background
+                # instructions.
+                loop_count = 0
+            remainder = number - block_size * loop_count
+            blocked_stmt = ""
+
+            if loop_count:
+                unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
+                blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"
+
+            if remainder:
+                blocked_stmt += "\n".join([stmt] * remainder)
+
+            return textwrap.indent(blocked_stmt, " " * indent)
+
+        pass_baseline = (
+            "callgrind_bindings._valgrind_toggle()\n"
+            f"{block_stmt('pass')}\n"
+            "callgrind_bindings._valgrind_toggle_and_dump_stats()"
+        )
+
+        return textwrap.dedent(r"""
+            import gc
+            import os
+            import pickle
+            import subprocess
+            import sys
+            import time
+
+            # Mitigate https://github.com/pytorch/pytorch/issues/37377
+            # which can sometimes cause the subprocess call to fail.
+            import numpy as np
+
+            import torch
+            torch.set_num_threads({num_threads})
+
+            {bindings_import}
+
+            PID = os.getpid()
+
+            def log_failure(msg):
+                with open({error_log_repr}, "wt") as f:
+                    f.write(msg)
+                sys.exit(1)
+
+            def check_result(completed_process):
+                if completed_process.returncode:
+                    log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
+                return completed_process
+
+            # =============================================================================
+            # == Check that subprocess matches parent =====================================
+            # =============================================================================
+            if os.path.realpath(sys.executable) != "{parent_interpreter}":
+                log_failure(
+                    "Interpreter mismatch:\n"
+                    f"  {{os.path.realpath(sys.executable)}}\n    vs.\n  {parent_interpreter}"
+                )
+
+            if torch.__file__ != "{torch_file}":
+                log_failure(
+                    "PyTorch does not match expected file:\n"
+                    f"  {{torch.__file__}}\n    vs.\n  {torch_file}"
+                )
+
+            # =============================================================================
+            # == User specified setup =====================================================
+            # =============================================================================
+            # Load serialized globals
+            {load_globals}
+
+            # User setup str
+            {setup}
+
+            for _ in range({warmup_number}):
+            {indented_stmt}
+
+            # =============================================================================
+            # == Callgrind management =====================================================
+            # =============================================================================
+            with open("{stat_log}", "wb") as stat_file:
+                # If many instances of callgrind are running at once, the output of
+                # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
+                # to deadlock. So instead we use a file.
+                callgrind_stat = check_result(subprocess.run(
+                    ["callgrind_control", "--stat"],
+                    stdout=stat_file,
+                    stderr=subprocess.STDOUT,
+                ))
+
+            with open("{stat_log}", "rt") as stat_file:
+                stat_lines = stat_file.read().splitlines()
+
+            if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
+                log_failure("Process does not appear to be running callgrind.")
+
+            gc.collect()
+            time.sleep(0.01)
+
+            # =============================================================================
+            # == User code block ==========================================================
+            # =============================================================================
+            for _ in range({repeats}):
+                callgrind_bindings._valgrind_toggle()
+            {blocked_stmt}
+                callgrind_bindings._valgrind_toggle_and_dump_stats()
+                gc.collect()
+
+            {baseline}
+        """).strip().format(
+            indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
+            blocked_stmt=block_stmt(task_spec.stmt, indent=4),
+            baseline=(pass_baseline if collect_baseline else ""),
+            number=number,
+            repeats=repeats,
+            load_globals=globals.construct(),
+            setup=task_spec.setup,
+            warmup_number=min(number, 10),
+            num_threads=task_spec.num_threads,
+            error_log_repr=repr(error_log),
+            stat_log=stat_log,
+            parent_interpreter=os.path.realpath(sys.executable),
+            torch_file=torch.__file__,
+            bindings_import=(
+                "import torch._C as callgrind_bindings" if bindings is None
+                else f"import {bindings.__name__} as callgrind_bindings"),
+        )
+
+
+CALLGRIND_SINGLETON: _ValgrindWrapper | None = None
+def wrapper_singleton() -> _ValgrindWrapper:
+    global CALLGRIND_SINGLETON
+    if CALLGRIND_SINGLETON is None:
+        CALLGRIND_SINGLETON = _ValgrindWrapper()
+    return CALLGRIND_SINGLETON
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
new file mode 100644
index 0000000000000000000000000000000000000000..d33dd30932aa86b8284cb93d0e29ec646e820197
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
@@ -0,0 +1,7157 @@
+/* -*- c -*-
+   ----------------------------------------------------------------
+
+   Notice that the following BSD-style license applies to this one
+   file (valgrind.h) only.  The rest of Valgrind is licensed under the
+   terms of the GNU General Public License, version 2, unless
+   otherwise indicated.  See the COPYING file in the source
+   distribution for details.
+
+   ----------------------------------------------------------------
+
+   This file is part of Valgrind, a dynamic binary instrumentation
+   framework.
+
+   Copyright (C) 2000-2017 Julian Seward.  All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   1. Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+   2. The origin of this software must not be misrepresented; you must 
+      not claim that you wrote the original software.  If you use this 
+      software in a product, an acknowledgment in the product 
+      documentation would be appreciated but is not required.
+
+   3. Altered source versions must be plainly marked as such, and must
+      not be misrepresented as being the original software.
+
+   4. The name of the author may not be used to endorse or promote 
+      products derived from this software without specific prior written 
+      permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
+   OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+   WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+   ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+   DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+   GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+   WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+   ----------------------------------------------------------------
+
+   Notice that the above BSD-style license applies to this one file
+   (valgrind.h) only.  The entire rest of Valgrind is licensed under
+   the terms of the GNU General Public License, version 2.  See the
+   COPYING file in the source distribution for details.
+
+   ---------------------------------------------------------------- 
+*/
+
+
+/* This file is for inclusion into client (your!) code.
+
+   You can use these macros to manipulate and query Valgrind's 
+   execution inside your own programs.
+
+   The resulting executables will still run without Valgrind, just a
+   little bit more slowly than they otherwise would, but otherwise
+   unchanged.  When not running on valgrind, each client request
+   consumes very few (eg. 7) instructions, so the resulting performance
+   loss is negligible unless you plan to execute client requests
+   millions of times per second.  Nevertheless, if that is still a
+   problem, you can compile with the NVALGRIND symbol defined (gcc
+   -DNVALGRIND) so that client requests are not even compiled in.  */
+
+#ifndef __VALGRIND_H
+#define __VALGRIND_H
+
+
+/* ------------------------------------------------------------------ */
+/* VERSION NUMBER OF VALGRIND                                         */
+/* ------------------------------------------------------------------ */
+
+/* Specify Valgrind's version number, so that user code can
+   conditionally compile based on our version number.  Note that these
+   were introduced at version 3.6 and so do not exist in version 3.5
+   or earlier.  The recommended way to use them to check for "version
+   X.Y or later" is (eg)
+
+#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__)   \
+    && (__VALGRIND_MAJOR__ > 3                                   \
+        || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6))
+*/
+#define __VALGRIND_MAJOR__    3
+#define __VALGRIND_MINOR__    17
+
+
+#include <stdarg.h>
+
+/* Nb: this file might be included in a file compiled with -ansi.  So
+   we can't use C++ style "//" comments nor the "asm" keyword (instead
+   use "__asm__"). */
+
+/* Derive some tags indicating what the target platform is.  Note
+   that in this file we're using the compiler's CPP symbols for
+   identifying architectures, which are different to the ones we use
+   within the rest of Valgrind.  Note, __powerpc__ is active for both
+   32 and 64-bit PPC, whereas __powerpc64__ is only active for the
+   latter (on Linux, that is).
+
+   Misc note: how to find out what's predefined in gcc by default:
+   gcc -Wp,-dM somefile.c
+*/
+/* Start from a clean slate: drop any PLAT_* tag that is already defined
+   before the selection chain below runs. */
+#undef PLAT_x86_darwin
+#undef PLAT_amd64_darwin
+#undef PLAT_x86_win32
+#undef PLAT_amd64_win64
+#undef PLAT_x86_linux
+#undef PLAT_amd64_linux
+#undef PLAT_ppc32_linux
+#undef PLAT_ppc64be_linux
+#undef PLAT_ppc64le_linux
+#undef PLAT_arm_linux
+#undef PLAT_arm64_linux
+#undef PLAT_s390x_linux
+#undef PLAT_mips32_linux
+#undef PLAT_mips64_linux
+#undef PLAT_nanomips_linux
+#undef PLAT_x86_solaris
+#undef PLAT_amd64_solaris
+
+
+/* A single #if/#elif chain, so at most one PLAT_* tag ends up defined
+   (the first matching branch wins).  If no branch matches, the final
+   #else defines NVALGRIND instead, unless it is already defined. */
+#if defined(__APPLE__) && defined(__i386__)
+#  define PLAT_x86_darwin 1
+#elif defined(__APPLE__) && defined(__x86_64__)
+#  define PLAT_amd64_darwin 1
+#elif (defined(__MINGW32__) && defined(__i386__)) \
+      || defined(__CYGWIN32__) \
+      || (defined(_WIN32) && defined(_M_IX86))
+#  define PLAT_x86_win32 1
+#elif (defined(__MINGW32__) && defined(__x86_64__)) \
+      || (defined(_WIN32) && defined(_M_X64))
+/* __MINGW32__ and _WIN32 are defined in 64 bit mode as well. */
+#  define PLAT_amd64_win64 1
+#elif defined(__linux__) && defined(__i386__)
+#  define PLAT_x86_linux 1
+#elif defined(__linux__) && defined(__x86_64__) && !defined(__ILP32__)
+#  define PLAT_amd64_linux 1
+#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__)
+#  define PLAT_ppc32_linux 1
+#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF != 2
+/* Big Endian uses ELF version 1 */
+#  define PLAT_ppc64be_linux 1
+#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF == 2
+/* Little Endian uses ELF version 2 */
+#  define PLAT_ppc64le_linux 1
+#elif defined(__linux__) && defined(__arm__) && !defined(__aarch64__)
+#  define PLAT_arm_linux 1
+#elif defined(__linux__) && defined(__aarch64__) && !defined(__arm__)
+#  define PLAT_arm64_linux 1
+#elif defined(__linux__) && defined(__s390__) && defined(__s390x__)
+#  define PLAT_s390x_linux 1
+#elif defined(__linux__) && defined(__mips__) && (__mips==64)
+#  define PLAT_mips64_linux 1
+#elif defined(__linux__) && defined(__mips__) && (__mips==32)
+#  define PLAT_mips32_linux 1
+#elif defined(__linux__) && defined(__nanomips__)
+#  define PLAT_nanomips_linux 1
+#elif defined(__sun) && defined(__i386__)
+#  define PLAT_x86_solaris 1
+#elif defined(__sun) && defined(__x86_64__)
+#  define PLAT_amd64_solaris 1
+#else
+/* If we're not compiling for our target platform, don't generate
+   any inline asms.  */
+#  if !defined(NVALGRIND)
+#    define NVALGRIND 1
+#  endif
+#endif
+
+
+/* ------------------------------------------------------------------ */
+/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS.  There is nothing */
+/* in here of use to end-users -- skip to the next section.           */
+/* ------------------------------------------------------------------ */
+
+/*
+ * VALGRIND_DO_CLIENT_REQUEST(): a statement that invokes a Valgrind client
+ * request. Accepts both pointers and integers as arguments.
+ *
+ * VALGRIND_DO_CLIENT_REQUEST_STMT(): a statement that invokes a Valgrind
+ * client request that does not return a value.
+ *
+ * VALGRIND_DO_CLIENT_REQUEST_EXPR(): a C expression that invokes a Valgrind
+ * client request and whose value equals the client request result.  Accepts
+ * both pointers and integers as arguments.  Note that such calls are not
+ * necessarily pure functions -- they may have side effects.
+ */
+
+/* Assigns the value of VALGRIND_DO_CLIENT_REQUEST_EXPR to _zzq_rlval.
+   Wrapped in do { } while (0) so the expansion is a single statement. */
+#define VALGRIND_DO_CLIENT_REQUEST(_zzq_rlval, _zzq_default,            \
+                                   _zzq_request, _zzq_arg1, _zzq_arg2,  \
+                                   _zzq_arg3, _zzq_arg4, _zzq_arg5)     \
+  do { (_zzq_rlval) = VALGRIND_DO_CLIENT_REQUEST_EXPR((_zzq_default),   \
+                        (_zzq_request), (_zzq_arg1), (_zzq_arg2),       \
+                        (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0)
+
+/* Same request, but the result is cast to void and discarded; 0 is
+   passed as the default value. */
+#define VALGRIND_DO_CLIENT_REQUEST_STMT(_zzq_request, _zzq_arg1,        \
+                           _zzq_arg2,  _zzq_arg3, _zzq_arg4, _zzq_arg5) \
+  do { (void) VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                        \
+                    (_zzq_request), (_zzq_arg1), (_zzq_arg2),           \
+                    (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0)
+
+#if defined(NVALGRIND)
+
+/* Define NVALGRIND to completely remove the Valgrind magic sequence
+   from the compiled code (analogous to NDEBUG's effects on
+   assert()).
+
+   In this form the macro expands to the default value alone: the
+   request code and the five arguments do not appear in the expansion,
+   so any side effects in those argument expressions are not executed. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+      (_zzq_default)
+
+#else  /* ! NVALGRIND */
+
+/* The following defines the magic code sequences which the JITter
+   spots and handles magically.  Don't look too closely at them as
+   they will rot your brain.
+
+   The assembly code sequences for all architectures are in this one
+   file.  This is because this file must be stand-alone, and we don't
+   want to have multiple files.
+
+   For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default
+   value gets put in the return slot, so that everything works when
+   this is executed not under Valgrind.  Args are passed in a memory
+   block, and so there's no intrinsic limit to the number that could
+   be passed, but it's currently five.
+   
+   The macro args are: 
+      _zzq_rlval    result lvalue
+      _zzq_default  default value (result returned when running on real CPU)
+      _zzq_request  request code
+      _zzq_arg1..5  request params
+
+   The other two macros are used to support function wrapping, and are
+   a lot simpler.  VALGRIND_GET_NR_CONTEXT returns the value of the
+   guest's NRADDR pseudo-register and whatever other information is
+   needed to safely call the original from the wrapper: on
+   ppc64-linux, the R2 value at the divert point is also needed.  This
+   information is abstracted into a user-visible type, OrigFn.
+
+   VALGRIND_CALL_NOREDIR_* behaves the same as the following on the
+   guest, but guarantees that the branch instruction will not be
+   redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64:
+   branch-and-link-to-r11.  VALGRIND_CALL_NOREDIR is just text, not a
+   complete inline asm, since it needs to be combined with more magic
+   inline asm stuff to be useful.
+*/
+
+/* ----------------- x86-{linux,darwin,solaris} ---------------- */
+
+#if defined(PLAT_x86_linux)  ||  defined(PLAT_x86_darwin)  \
+    ||  (defined(PLAT_x86_win32) && defined(__GNUC__)) \
+    ||  defined(PLAT_x86_solaris)
+
+/* Handle filled in by VALGRIND_GET_NR_CONTEXT below; its only field
+   receives the guest NRADDR value. */
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* Marker prefix shared by every magic sequence in this section: four
+   rotates of %edi by 3, 13, 29 and 19 bits.  The amounts sum to 64, so
+   the 32-bit register is left with its original value. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "roll $3,  %%edi ; roll $13, %%edi\n\t"      \
+                     "roll $29, %%edi ; roll $19, %%edi\n\t"
+
+/* GNU statement expression.  The request code and the five arguments
+   are stored (cast to unsigned int) in a six-element array whose
+   address is passed in %eax.  The default is loaded into %edx through
+   the "0" constraint, which ties it to the output operand, and the
+   expression's value is whatever %edx holds after the sequence. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+  __extension__                                                   \
+  ({volatile unsigned int _zzq_args[6];                           \
+    volatile unsigned int _zzq_result;                            \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %EDX = client_request ( %EAX ) */         \
+                     "xchgl %%ebx,%%ebx"                          \
+                     : "=d" (_zzq_result)                         \
+                     : "a" (&_zzq_args[0]), "0" (_zzq_default)    \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_result;                                                  \
+  })
+
+/* Stores the value found in %eax after the preamble + "xchgl %ecx,%ecx"
+   sequence into _zzq_rlval.nraddr.  _zzq_rlval must be an OrigFn
+   lvalue.  Note that this expands to a bare { } block rather than a
+   do { } while (0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned int __addr;                                 \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %EAX = guest_NRADDR */                    \
+                     "xchgl %%ecx,%%ecx"                          \
+                     : "=a" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* String fragment only (preamble + "xchgl %edx,%edx"), not a complete
+   asm statement; meant to be pasted into a larger inline asm. */
+#define VALGRIND_CALL_NOREDIR_EAX                                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* call-noredir *%EAX */                     \
+                     "xchgl %%edx,%%edx\n\t"
+
+/* Emits the preamble followed by "xchgl %edi,%edi".  Takes no operands
+   and produces no value. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "xchgl %%edi,%%edi\n\t"                     \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_x86_linux || PLAT_x86_darwin || (PLAT_x86_win32 && __GNUC__)
+          || PLAT_x86_solaris */
+
+/* ------------------------- x86-Win32 ------------------------- */
+
+#if defined(PLAT_x86_win32) && !defined(__GNUC__)
+
+/* Handle filled in by VALGRIND_GET_NR_CONTEXT below; its only field
+   receives the guest NRADDR value. */
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#if defined(_MSC_VER)
+
+/* Same four rotates of edi (3, 13, 29, 19 bits; 64 in total) as the
+   gcc-style x86 preamble, written with MSVC __asm syntax. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     __asm rol edi, 3  __asm rol edi, 13          \
+                     __asm rol edi, 29 __asm rol edi, 19
+
+/* Here the request is a call to the static inline function defined
+   just below, with every macro argument cast to uintptr_t. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+    valgrind_do_client_request_expr((uintptr_t)(_zzq_default),    \
+        (uintptr_t)(_zzq_request), (uintptr_t)(_zzq_arg1),        \
+        (uintptr_t)(_zzq_arg2), (uintptr_t)(_zzq_arg3),           \
+        (uintptr_t)(_zzq_arg4), (uintptr_t)(_zzq_arg5))
+
+/* Stores the request code and five arguments in a six-element array,
+   loads its address into eax and the default into edx, executes the
+   magic sequence, and returns the value then found in edx.  The result
+   is read back through a 32-bit unsigned int. */
+static __inline uintptr_t
+valgrind_do_client_request_expr(uintptr_t _zzq_default, uintptr_t _zzq_request,
+                                uintptr_t _zzq_arg1, uintptr_t _zzq_arg2,
+                                uintptr_t _zzq_arg3, uintptr_t _zzq_arg4,
+                                uintptr_t _zzq_arg5)
+{
+    volatile uintptr_t _zzq_args[6];
+    volatile unsigned int _zzq_result;
+    _zzq_args[0] = (uintptr_t)(_zzq_request);
+    _zzq_args[1] = (uintptr_t)(_zzq_arg1);
+    _zzq_args[2] = (uintptr_t)(_zzq_arg2);
+    _zzq_args[3] = (uintptr_t)(_zzq_arg3);
+    _zzq_args[4] = (uintptr_t)(_zzq_arg4);
+    _zzq_args[5] = (uintptr_t)(_zzq_arg5);
+    __asm { __asm lea eax, _zzq_args __asm mov edx, _zzq_default
+            __SPECIAL_INSTRUCTION_PREAMBLE
+            /* %EDX = client_request ( %EAX ) */
+            __asm xchg ebx,ebx
+            __asm mov _zzq_result, edx
+    }
+    return _zzq_result;
+}
+
+/* Stores the value found in eax after the preamble + "xchg ecx,ecx"
+   sequence into _zzq_rlval.nraddr.  Expands to a bare { } block. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned int __addr;                                 \
+    __asm { __SPECIAL_INSTRUCTION_PREAMBLE                        \
+            /* %EAX = guest_NRADDR */                             \
+            __asm xchg ecx,ecx                                    \
+            __asm mov __addr, eax                                 \
+    }                                                             \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* Defined as the bare token ERROR.  NOTE(review): presumably so that
+   any use fails to compile, i.e. call-noredir is not provided for this
+   compiler -- confirm. */
+#define VALGRIND_CALL_NOREDIR_EAX ERROR
+
+/* Emits the preamble followed by "xchg edi,edi".  Takes no operands
+   and produces no value. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm { __SPECIAL_INSTRUCTION_PREAMBLE                       \
+            __asm xchg edi,edi                                   \
+    }                                                            \
+ } while (0)
+
+#else
+/* Only MSVC is handled for a non-GNU compiler on x86-win32. */
+#error Unsupported compiler.
+#endif
+
+#endif /* PLAT_x86_win32 && !__GNUC__ */
+
+/* ----------------- amd64-{linux,darwin,solaris} --------------- */
+
+#if defined(PLAT_amd64_linux)  ||  defined(PLAT_amd64_darwin) \
+    ||  defined(PLAT_amd64_solaris) \
+    ||  (defined(PLAT_amd64_win64) && defined(__GNUC__))
+
+/* Handle filled in by VALGRIND_GET_NR_CONTEXT below; its only field
+   receives the guest NRADDR value. */
+typedef
+   struct { 
+      unsigned long int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* Marker prefix shared by every magic sequence in this section: four
+   rotates of %rdi by 3, 13, 61 and 51 bits.  The amounts sum to 128, so
+   the 64-bit register is left with its original value. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "rolq $3,  %%rdi ; rolq $13, %%rdi\n\t"      \
+                     "rolq $61, %%rdi ; rolq $51, %%rdi\n\t"
+
+/* GNU statement expression.  The request code and the five arguments
+   are stored (cast to unsigned long int) in a six-element array whose
+   address is passed in %rax.  The default is loaded into %rdx through
+   the "0" constraint, which ties it to the output operand, and the
+   expression's value is whatever %rdx holds after the sequence. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+    __extension__                                                 \
+    ({ volatile unsigned long int _zzq_args[6];                   \
+    volatile unsigned long int _zzq_result;                       \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %RDX = client_request ( %RAX ) */         \
+                     "xchgq %%rbx,%%rbx"                          \
+                     : "=d" (_zzq_result)                         \
+                     : "a" (&_zzq_args[0]), "0" (_zzq_default)    \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_result;                                                  \
+    })
+
+/* Stores the value found in %rax after the preamble + "xchgq %rcx,%rcx"
+   sequence into _zzq_rlval.nraddr.  _zzq_rlval must be an OrigFn
+   lvalue.  Note that this expands to a bare { } block rather than a
+   do { } while (0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned long int __addr;                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %RAX = guest_NRADDR */                    \
+                     "xchgq %%rcx,%%rcx"                          \
+                     : "=a" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* String fragment only (preamble + "xchgq %rdx,%rdx"), not a complete
+   asm statement; meant to be pasted into a larger inline asm. */
+#define VALGRIND_CALL_NOREDIR_RAX                                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* call-noredir *%RAX */                     \
+                     "xchgq %%rdx,%%rdx\n\t"
+
+/* Emits the preamble followed by "xchgq %rdi,%rdi".  Takes no operands
+   and produces no value. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "xchgq %%rdi,%%rdi\n\t"                     \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris
+          || (PLAT_amd64_win64 && __GNUC__) */
+
+/* ------------------------- amd64-Win64 ------------------------- */
+
+#if defined(PLAT_amd64_win64) && !defined(__GNUC__)
+
+#error Unsupported compiler.
+
+#endif /* PLAT_amd64_win64 && !__GNUC__ */
+
+/* ------------------------ ppc32-linux ------------------------ */
+
+#if defined(PLAT_ppc32_linux)
+
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                    "rlwinm 0,0,3,0,31  ; rlwinm 0,0,13,0,31\n\t" \
+                    "rlwinm 0,0,29,0,31 ; rlwinm 0,0,19,0,31\n\t"
+
+/* ppc32: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element unsigned int array; r3 is loaded with
+   _zzq_default and r4 with the array's address, then the preamble
+   and the "or 1,1,1" marker are emitted and the result is read back
+   from r3.  None of the marker instructions changes the contents of
+   r3 when executed as ordinary instructions, so in that case the
+   expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+    __extension__                                                 \
+  ({         unsigned int  _zzq_args[6];                          \
+             unsigned int  _zzq_result;                           \
+             unsigned int* _zzq_ptr;                              \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    _zzq_ptr = _zzq_args;                                         \
+    __asm__ volatile("mr 3,%1\n\t" /*default*/                    \
+                     "mr 4,%2\n\t" /*ptr*/                        \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = client_request ( %R4 ) */           \
+                     "or 1,1,1\n\t"                               \
+                     "mr %0,3"     /*result*/                     \
+                     : "=b" (_zzq_result)                         \
+                     : "b" (_zzq_default), "b" (_zzq_ptr)         \
+                     : "cc", "memory", "r3", "r4");               \
+    _zzq_result;                                                  \
+    })
+
+/* ppc32: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble and
+   the "or 2,2,2" marker, then copies r3 (guest_NRADDR, per the inline
+   comment) into _zzq_rlval.nraddr.  Expands to a bare { } block, not
+   a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned int __addr;                                          \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR */                     \
+                     "or 2,2,2\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* ppc32: asm-template string fragment (not a statement) consisting of
+   the preamble followed by the "or 3,3,3" marker.  It is meant to be
+   pasted into a larger __asm__ template; per the inline comment it
+   stands for a non-redirected branch-and-link to the address in r11. */
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R11 */       \
+                     "or 3,3,3\n\t"
+
+/* ppc32: statement macro emitting the preamble followed by the
+   "or 5,5,5" marker.  The asm statement declares no operands and no
+   clobbers.  NOTE(review): the name suggests a VEX IR injection
+   request -- confirm against the ppc guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or 5,5,5\n\t"                              \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_ppc32_linux */
+
+/* ------------------------ ppc64-linux ------------------------ */
+
+#if defined(PLAT_ppc64be_linux)
+
+/* ppc64be: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores two successive r3 readings into nraddr and r2. */
+typedef
+   struct { 
+      unsigned long int nraddr; /* where's the code? */
+      unsigned long int r2;  /* what tocptr do we need? */
+   }
+   OrigFn;
+
+/* ppc64be marker sequence that prefixes every special request below.
+   Four 64-bit rotate-left instructions on r0; the amounts
+   3+13+61+51 = 128 make two complete 64-bit rotations, so r0 ends up
+   with its original value. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "rotldi 0,0,3  ; rotldi 0,0,13\n\t"          \
+                     "rotldi 0,0,61 ; rotldi 0,0,51\n\t"
+
+/* ppc64be: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element unsigned long int array; r3 is loaded
+   with _zzq_default and r4 with the array's address, then the
+   preamble and the "or 1,1,1" marker are emitted and the result is
+   read back from r3.  None of the marker instructions changes the
+   contents of r3 when executed as ordinary instructions, so in that
+   case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({         unsigned long int  _zzq_args[6];                     \
+             unsigned long int  _zzq_result;                      \
+             unsigned long int* _zzq_ptr;                         \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    _zzq_ptr = _zzq_args;                                         \
+    __asm__ volatile("mr 3,%1\n\t" /*default*/                    \
+                     "mr 4,%2\n\t" /*ptr*/                        \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = client_request ( %R4 ) */           \
+                     "or 1,1,1\n\t"                               \
+                     "mr %0,3"     /*result*/                     \
+                     : "=b" (_zzq_result)                         \
+                     : "b" (_zzq_default), "b" (_zzq_ptr)         \
+                     : "cc", "memory", "r3", "r4");               \
+    _zzq_result;                                                  \
+  })
+
+/* ppc64be: fill the OrigFn lvalue _zzq_rlval using two marker
+   sequences.  The first (preamble + "or 2,2,2") leaves guest_NRADDR
+   in r3, which is stored in nraddr; the second (preamble +
+   "or 4,4,4") leaves guest_NRADDR_GPR2 in r3, which is stored in r2
+   (register meanings per the inline comments).  Expands to a bare
+   { } block, not a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned long int __addr;                                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR */                     \
+                     "or 2,2,2\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR_GPR2 */                \
+                     "or 4,4,4\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->r2 = __addr;                                       \
+  }
+
+/* ppc64be: asm-template string fragment (not a statement) consisting
+   of the preamble followed by the "or 3,3,3" marker.  It is meant to
+   be pasted into a larger __asm__ template; per the inline comment it
+   stands for a non-redirected branch-and-link to the address in r11. */
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R11 */       \
+                     "or 3,3,3\n\t"
+
+/* ppc64be: statement macro emitting the preamble followed by the
+   "or 5,5,5" marker.  The asm statement declares no operands and no
+   clobbers.  NOTE(review): the name suggests a VEX IR injection
+   request -- confirm against the ppc guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or 5,5,5\n\t"                              \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_ppc64be_linux */
+
+#if defined(PLAT_ppc64le_linux)
+
+/* ppc64le: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores two successive r3 readings into nraddr and r2. */
+typedef
+   struct {
+      unsigned long int nraddr; /* where's the code? */
+      unsigned long int r2;     /* what tocptr do we need? */
+   }
+   OrigFn;
+
+/* ppc64le marker sequence that prefixes every special request below.
+   Four 64-bit rotate-left instructions on r0; the amounts
+   3+13+61+51 = 128 make two complete 64-bit rotations, so r0 ends up
+   with its original value. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "rotldi 0,0,3  ; rotldi 0,0,13\n\t"          \
+                     "rotldi 0,0,61 ; rotldi 0,0,51\n\t"
+
+/* ppc64le: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element unsigned long int array; r3 is loaded
+   with _zzq_default and r4 with the array's address, then the
+   preamble and the "or 1,1,1" marker are emitted and the result is
+   read back from r3.  None of the marker instructions changes the
+   contents of r3 when executed as ordinary instructions, so in that
+   case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({         unsigned long int  _zzq_args[6];                     \
+             unsigned long int  _zzq_result;                      \
+             unsigned long int* _zzq_ptr;                         \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    _zzq_ptr = _zzq_args;                                         \
+    __asm__ volatile("mr 3,%1\n\t" /*default*/                    \
+                     "mr 4,%2\n\t" /*ptr*/                        \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = client_request ( %R4 ) */           \
+                     "or 1,1,1\n\t"                               \
+                     "mr %0,3"     /*result*/                     \
+                     : "=b" (_zzq_result)                         \
+                     : "b" (_zzq_default), "b" (_zzq_ptr)         \
+                     : "cc", "memory", "r3", "r4");               \
+    _zzq_result;                                                  \
+  })
+
+/* ppc64le: fill the OrigFn lvalue _zzq_rlval using two marker
+   sequences.  The first (preamble + "or 2,2,2") leaves guest_NRADDR
+   in r3, which is stored in nraddr; the second (preamble +
+   "or 4,4,4") leaves guest_NRADDR_GPR2 in r3, which is stored in r2
+   (register meanings per the inline comments).  Expands to a bare
+   { } block, not a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned long int __addr;                                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR */                     \
+                     "or 2,2,2\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR_GPR2 */                \
+                     "or 4,4,4\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->r2 = __addr;                                       \
+  }
+
+/* ppc64le: asm-template string fragment (not a statement) consisting
+   of the preamble followed by the "or 3,3,3" marker.  It is meant to
+   be pasted into a larger __asm__ template; per the inline comment it
+   stands for a non-redirected branch-and-link to the address in r12
+   (the big-endian variant names r11 instead). */
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R12 */       \
+                     "or 3,3,3\n\t"
+
+/* ppc64le: statement macro emitting the preamble followed by the
+   "or 5,5,5" marker.  The asm statement declares no operands and no
+   clobbers.  NOTE(review): the name suggests a VEX IR injection
+   request -- confirm against the ppc guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or 5,5,5\n\t"                              \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_ppc64le_linux */
+
+/* ------------------------- arm-linux ------------------------- */
+
+#if defined(PLAT_arm_linux)
+
+/* arm: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores the value it reads from r3 into nraddr. */
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* arm marker sequence that prefixes every special request below.
+   Four rotate-right moves of r12 onto itself; the amounts
+   3+13+29+19 = 64 make two complete 32-bit rotations, so r12 ends up
+   with its original value. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+            "mov r12, r12, ror #3  ; mov r12, r12, ror #13 \n\t"  \
+            "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t"
+
+/* arm: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element volatile unsigned int array; r3 is
+   loaded with _zzq_default and r4 with the array's address, then the
+   preamble and the "orr r10, r10, r10" marker are emitted and the
+   result is read back from r3.  None of the marker instructions
+   changes the contents of r3 when executed as ordinary instructions,
+   so in that case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({volatile unsigned int  _zzq_args[6];                          \
+    volatile unsigned int  _zzq_result;                           \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    __asm__ volatile("mov r3, %1\n\t" /*default*/                 \
+                     "mov r4, %2\n\t" /*ptr*/                     \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* R3 = client_request ( R4 ) */             \
+                     "orr r10, r10, r10\n\t"                      \
+                     "mov %0, r3"     /*result*/                  \
+                     : "=r" (_zzq_result)                         \
+                     : "r" (_zzq_default), "r" (&_zzq_args[0])    \
+                     : "cc","memory", "r3", "r4");                \
+    _zzq_result;                                                  \
+  })
+
+/* arm: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble and the
+   "orr r11, r11, r11" marker, then copies r3 (guest_NRADDR, per the
+   inline comment) into _zzq_rlval.nraddr.  Expands to a bare { }
+   block, not a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned int __addr;                                          \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* R3 = guest_NRADDR */                      \
+                     "orr r11, r11, r11\n\t"                      \
+                     "mov %0, r3"                                 \
+                     : "=r" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* arm: asm-template string fragment (not a statement) consisting of
+   the preamble followed by the "orr r12, r12, r12" marker.  It is
+   meant to be pasted into a larger __asm__ template; per the inline
+   comment it stands for a non-redirected branch-and-link to the
+   address in r4. */
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R4 */        \
+                     "orr r12, r12, r12\n\t"
+
+/* arm: statement macro emitting the preamble followed by the
+   "orr r9, r9, r9" marker, with "cc" and "memory" declared clobbered.
+   NOTE(review): the name suggests a VEX IR injection request --
+   confirm against the arm guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "orr r9, r9, r9\n\t"                        \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_arm_linux */
+
+/* ------------------------ arm64-linux ------------------------- */
+
+#if defined(PLAT_arm64_linux)
+
+/* arm64: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores the value it reads from x3 into nraddr. */
+typedef
+   struct { 
+      unsigned long int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* arm64 marker sequence that prefixes every special request below.
+   Four rotate-right instructions on x12; the amounts
+   3+13+51+61 = 128 make two complete 64-bit rotations, so x12 ends
+   up with its original value. */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+            "ror x12, x12, #3  ;  ror x12, x12, #13 \n\t"         \
+            "ror x12, x12, #51 ;  ror x12, x12, #61 \n\t"
+
+/* arm64: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element volatile unsigned long int array; x3
+   is loaded with _zzq_default (cast to unsigned long int) and x4
+   with the array's address, then the preamble and the
+   "orr x10, x10, x10" marker are emitted and the result is read back
+   from x3.  None of the marker instructions changes the contents of
+   x3 when executed as ordinary instructions, so in that case the
+   expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({volatile unsigned long int  _zzq_args[6];                     \
+    volatile unsigned long int  _zzq_result;                      \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    __asm__ volatile("mov x3, %1\n\t" /*default*/                 \
+                     "mov x4, %2\n\t" /*ptr*/                     \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* X3 = client_request ( X4 ) */             \
+                     "orr x10, x10, x10\n\t"                      \
+                     "mov %0, x3"     /*result*/                  \
+                     : "=r" (_zzq_result)                         \
+                     : "r" ((unsigned long int)(_zzq_default)),   \
+                       "r" (&_zzq_args[0])                        \
+                     : "cc","memory", "x3", "x4");                \
+    _zzq_result;                                                  \
+  })
+
+/* arm64: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble and
+   the "orr x11, x11, x11" marker, then copies x3 (guest_NRADDR, per
+   the inline comment) into _zzq_rlval.nraddr.  Expands to a bare
+   { } block, not a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned long int __addr;                                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* X3 = guest_NRADDR */                      \
+                     "orr x11, x11, x11\n\t"                      \
+                     "mov %0, x3"                                 \
+                     : "=r" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "x3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* arm64: asm-template string fragment (not a statement) consisting of
+   the preamble followed by the "orr x12, x12, x12" marker.  It is
+   meant to be pasted into a larger __asm__ template; per the inline
+   comment it stands for a non-redirected branch-and-link to the
+   address in x8. */
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir X8 */          \
+                     "orr x12, x12, x12\n\t"
+
+/* arm64: statement macro emitting the preamble followed by the
+   "orr x9, x9, x9" marker, with "cc" and "memory" declared clobbered.
+   NOTE(review): the name suggests a VEX IR injection request --
+   confirm against the arm64 guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "orr x9, x9, x9\n\t"                        \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_arm64_linux */
+
+/* ------------------------ s390x-linux ------------------------ */
+
+#if defined(PLAT_s390x_linux)
+
+/* s390x: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores the value it reads from r3 into nraddr. */
+typedef
+  struct {
+     unsigned long int nraddr; /* where's the code? */
+  }
+  OrigFn;
+
+/* __SPECIAL_INSTRUCTION_PREAMBLE will be used to identify Valgrind specific
+ * code. This detection is implemented in platform specific toIR.c
+ * (e.g. VEX/priv/guest_s390_decoder.c).
+ *
+ * The sequence is four "lr" instructions that each load a register from
+ * itself (r15, r1, r2, r3), so no register value changes.
+ */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                           \
+                     "lr 15,15\n\t"                              \
+                     "lr 1,1\n\t"                                \
+                     "lr 2,2\n\t"                                \
+                     "lr 3,3\n\t"
+
+/* s390x request selectors: one self-loading "lr" instruction each,
+   emitted directly after __SPECIAL_INSTRUCTION_PREAMBLE by the macros
+   below.  The register number distinguishes the request: 2 = client
+   request, 3 = get NR context, 4 = call without redirection,
+   5 = VEX inject IR. */
+#define __CLIENT_REQUEST_CODE "lr 2,2\n\t"
+#define __GET_NR_CONTEXT_CODE "lr 3,3\n\t"
+#define __CALL_NO_REDIR_CODE  "lr 4,4\n\t"
+#define __VEX_INJECT_IR_CODE  "lr 5,5\n\t"
+
+/* s390x: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element volatile unsigned long int array; r2
+   is loaded with the array's address and r3 with _zzq_default, then
+   the preamble and __CLIENT_REQUEST_CODE are emitted and the result
+   is read back from r3.  The "0" constraint places _zzq_default in
+   the same register as the output operand.  None of the marker
+   instructions changes the contents of r3 when executed as ordinary
+   instructions, so in that case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                         \
+       _zzq_default, _zzq_request,                               \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+  __extension__                                                  \
+ ({volatile unsigned long int _zzq_args[6];                      \
+   volatile unsigned long int _zzq_result;                       \
+   _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+   _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+   _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+   _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+   _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+   _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+   __asm__ volatile(/* r2 = args */                              \
+                    "lgr 2,%1\n\t"                               \
+                    /* r3 = default */                           \
+                    "lgr 3,%2\n\t"                               \
+                    __SPECIAL_INSTRUCTION_PREAMBLE               \
+                    __CLIENT_REQUEST_CODE                        \
+                    /* results = r3 */                           \
+                    "lgr %0, 3\n\t"                              \
+                    : "=d" (_zzq_result)                         \
+                    : "a" (&_zzq_args[0]), "0" (_zzq_default)    \
+                    : "cc", "2", "3", "memory"                   \
+                   );                                            \
+   _zzq_result;                                                  \
+ })
+
+/* s390x: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble and
+   __GET_NR_CONTEXT_CODE, then copies r3 into _zzq_rlval.nraddr.
+   Expands to a bare { } block, not a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                      \
+ { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+   volatile unsigned long int __addr;                            \
+   __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                    __GET_NR_CONTEXT_CODE                        \
+                    "lgr %0, 3\n\t"                              \
+                    : "=a" (__addr)                              \
+                    :                                            \
+                    : "cc", "3", "memory"                        \
+                   );                                            \
+   _zzq_orig->nraddr = __addr;                                   \
+ }
+
+/* s390x: asm-template string fragment (not a statement) consisting of
+   the preamble followed by __CALL_NO_REDIR_CODE, meant to be pasted
+   into a larger __asm__ template.  NOTE(review): the name implies the
+   call target is taken from r1 -- confirm against the s390 decoder. */
+#define VALGRIND_CALL_NOREDIR_R1                                 \
+                    __SPECIAL_INSTRUCTION_PREAMBLE               \
+                    __CALL_NO_REDIR_CODE
+
+/* s390x: statement macro emitting the preamble followed by
+   __VEX_INJECT_IR_CODE.  The asm statement declares no operands and
+   no clobbers. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     __VEX_INJECT_IR_CODE);                      \
+ } while (0)
+
+#endif /* PLAT_s390x_linux */
+
+/* ------------------------- mips32-linux ---------------- */
+
+#if defined(PLAT_mips32_linux)
+
+/* mips32: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores the value it reads from $11 into nraddr. */
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* mips32 marker sequence that prefixes every special request below:
+ * four "srl" instructions whose source and destination are both $0
+ * (the zero register), so no register value changes.  Their
+ * encodings, in order, are:
+ * .word  0x342
+ * .word  0x742
+ * .word  0xC2
+ * .word  0x4C2*/
+#define __SPECIAL_INSTRUCTION_PREAMBLE          \
+                     "srl $0, $0, 13\n\t"       \
+                     "srl $0, $0, 29\n\t"       \
+                     "srl $0, $0, 3\n\t"        \
+                     "srl $0, $0, 19\n\t"
+                    
+/* mips32: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element volatile unsigned int array; $11 is
+   loaded with _zzq_default and $12 with the array's address, then
+   the preamble and the "or $13, $13, $13" marker are emitted and the
+   result is read back from $11.  None of the marker instructions
+   changes the contents of $11 when executed as ordinary
+   instructions, so in that case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+       _zzq_default, _zzq_request,                                \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)     \
+  __extension__                                                   \
+  ({ volatile unsigned int _zzq_args[6];                          \
+    volatile unsigned int _zzq_result;                            \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+        __asm__ volatile("move $11, %1\n\t" /*default*/           \
+                     "move $12, %2\n\t" /*ptr*/                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* T3 = client_request ( T4 ) */             \
+                     "or $13, $13, $13\n\t"                       \
+                     "move %0, $11\n\t"     /*result*/            \
+                     : "=r" (_zzq_result)                         \
+                     : "r" (_zzq_default), "r" (&_zzq_args[0])    \
+                     : "$11", "$12", "memory");                   \
+    _zzq_result;                                                  \
+  })
+
+/* mips32: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble and
+   the "or $14, $14, $14" marker, then copies $11 into
+   _zzq_rlval.nraddr.  Only "$11" is declared clobbered.  Expands to a
+   bare { } block, not a do/while(0) statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned int __addr;                                 \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* $11 = guest_NRADDR */                     \
+                     "or $14, $14, $14\n\t"                       \
+                     "move %0, $11"     /*result*/                \
+                     : "=r" (__addr)                              \
+                     :                                            \
+                     : "$11"                                      \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+/* mips32: asm-template string fragment (not a statement) consisting
+   of the preamble followed by the "or $15, $15, $15" marker.  It is
+   meant to be pasted into a larger __asm__ template; per the inline
+   comment it stands for a non-redirected call to the address in t9. */
+#define VALGRIND_CALL_NOREDIR_T9                                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE              \
+                     /* call-noredir *%t9 */                     \
+                     "or $15, $15, $15\n\t"
+
+/* mips32: statement macro emitting the preamble followed by the
+   "or $11, $11, $11" marker.  The asm statement declares no operands
+   and no clobbers.  NOTE(review): the name suggests a VEX IR
+   injection request -- confirm against the mips guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or $11, $11, $11\n\t"                      \
+                    );                                           \
+ } while (0)
+
+
+#endif /* PLAT_mips32_linux */
+
+/* ------------------------- mips64-linux ---------------- */
+
+#if defined(PLAT_mips64_linux)
+
+/* mips64: context record filled in by VALGRIND_GET_NR_CONTEXT, which
+   stores the value it reads from $11 into nraddr. */
+typedef
+   struct {
+      unsigned long nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* mips64 marker sequence that prefixes every special request below:
+ * four "dsll" instructions whose source and destination are both $0
+ * (the zero register), so no register value changes.
+ * dsll $0,$0, 3
+ * dsll $0,$0, 13
+ * dsll $0,$0, 29
+ * dsll $0,$0, 19*/
+#define __SPECIAL_INSTRUCTION_PREAMBLE                              \
+                     "dsll $0,$0, 3 ; dsll $0,$0,13\n\t"            \
+                     "dsll $0,$0,29 ; dsll $0,$0,19\n\t"
+
+/* mips64: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element volatile unsigned long int array; $11
+   is loaded with _zzq_default and $12 with the array's address, then
+   the preamble and the "or $13, $13, $13" marker are emitted and the
+   result is read back from $11.  None of the marker instructions
+   changes the contents of $11 when executed as ordinary
+   instructions, so in that case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                            \
+       _zzq_default, _zzq_request,                                  \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)       \
+  __extension__                                                     \
+  ({ volatile unsigned long int _zzq_args[6];                       \
+    volatile unsigned long int _zzq_result;                         \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);               \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                  \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                  \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                  \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                  \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                  \
+        __asm__ volatile("move $11, %1\n\t" /*default*/             \
+                         "move $12, %2\n\t" /*ptr*/                 \
+                         __SPECIAL_INSTRUCTION_PREAMBLE             \
+                         /* $11 = client_request ( $12 ) */         \
+                         "or $13, $13, $13\n\t"                     \
+                         "move %0, $11\n\t"     /*result*/          \
+                         : "=r" (_zzq_result)                       \
+                         : "r" (_zzq_default), "r" (&_zzq_args[0])  \
+                         : "$11", "$12", "memory");                 \
+    _zzq_result;                                                    \
+  })
+
+/* mips64: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble and
+   the "or $14, $14, $14" marker, then copies $11 (guest_NRADDR, per
+   the inline comment) into _zzq_rlval.nraddr.  Only "$11" is declared
+   clobbered.  Expands to a bare { } block, not a do/while(0)
+   statement. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                         \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                     \
+    volatile unsigned long int __addr;                              \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* $11 = guest_NRADDR */                       \
+                     "or $14, $14, $14\n\t"                         \
+                     "move %0, $11"     /*result*/                  \
+                     : "=r" (__addr)                                \
+                     :                                              \
+                     : "$11");                                      \
+    _zzq_orig->nraddr = __addr;                                     \
+  }
+
+/* mips64: asm-template string fragment (not a statement) consisting
+   of the preamble followed by the "or $15, $15, $15" marker.  It is
+   meant to be pasted into a larger __asm__ template; per the inline
+   comment it stands for a non-redirected call to the address in $25. */
+#define VALGRIND_CALL_NOREDIR_T9                                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* call-noredir $25 */                         \
+                     "or $15, $15, $15\n\t"
+
+/* mips64: statement macro emitting the preamble followed by the
+   "or $11, $11, $11" marker.  The asm statement declares no operands
+   and no clobbers.  NOTE(review): the name suggests a VEX IR
+   injection request -- confirm against the mips guest decoder. */
+#define VALGRIND_VEX_INJECT_IR()                                    \
+ do {                                                               \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     "or $11, $11, $11\n\t"                         \
+                    );                                              \
+ } while (0)
+
+#endif /* PLAT_mips64_linux */
+
+#if defined(PLAT_nanomips_linux)
+
+/* nanomips: context record filled in by VALGRIND_GET_NR_CONTEXT,
+   which stores the value it reads from $a7 into nraddr. */
+typedef
+   struct {
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+/*
+   nanomips marker sequence that prefixes every special request below:
+   four 32-bit-encoded "srl" instructions whose source and destination
+   are both $zero, so no register value changes.  Encodings:
+
+   8000 c04d  srl  zero, zero, 13
+   8000 c05d  srl  zero, zero, 29
+   8000 c043  srl  zero, zero,  3
+   8000 c053  srl  zero, zero, 19
+*/
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE "srl[32] $zero, $zero, 13 \n\t" \
+                                       "srl[32] $zero, $zero, 29 \n\t" \
+                                       "srl[32] $zero, $zero, 3  \n\t" \
+                                       "srl[32] $zero, $zero, 19 \n\t"
+
+/* nanomips: issue a client request and evaluate to its result (GNU
+   statement expression).  The request code and the five arguments
+   are stored in a six-element volatile unsigned int array; $a7 is
+   loaded with _zzq_default and $t0 with the array's address, then
+   the preamble and the "or[32] $t0, $t0, $t0" marker are emitted and
+   the result is read back from $a7.  None of the marker instructions
+   changes the contents of $a7 when executed as ordinary
+   instructions, so in that case the expression yields _zzq_default. */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+       _zzq_default, _zzq_request,                                \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)     \
+  __extension__                                                   \
+  ({ volatile unsigned int _zzq_args[6];                          \
+    volatile unsigned int _zzq_result;                            \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    __asm__ volatile("move $a7, %1\n\t" /* default */             \
+                     "move $t0, %2\n\t" /* ptr */                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* $a7 = client_request( $t0 ) */            \
+                     "or[32] $t0, $t0, $t0\n\t"                   \
+                     "move %0, $a7\n\t"     /* result */          \
+                     : "=r" (_zzq_result)                         \
+                     : "r" (_zzq_default), "r" (&_zzq_args[0])    \
+                     : "$a7", "$t0", "memory");                   \
+    _zzq_result;                                                  \
+  })
+
+/* nanomips: fill the OrigFn lvalue _zzq_rlval.  Emits the preamble
+   and the "or[32] $t1, $t1, $t1" marker, then copies $a7
+   (guest_NRADDR, per the inline comment) into _zzq_rlval.nraddr.
+   Only "$a7" is declared clobbered.  Expands to a bare { } block,
+   not a do/while(0) statement.
+   NOTE(review): __addr is declared unsigned long int here while
+   OrigFn.nraddr is unsigned int on this platform -- confirm the
+   width difference is intentional. */
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                         \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                     \
+    volatile unsigned long int __addr;                              \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* $a7 = guest_NRADDR */                       \
+                     "or[32] $t1, $t1, $t1\n\t"                     \
+                     "move %0, $a7"     /*result*/                  \
+                     : "=r" (__addr)                                \
+                     :                                              \
+                     : "$a7");                                      \
+    _zzq_orig->nraddr = __addr;                                     \
+  }
+
+/* nanomips: asm-template string fragment (not a statement) consisting
+   of the preamble followed by the "or[32] $t2, $t2, $t2" marker.  It
+   is meant to be pasted into a larger __asm__ template; per the
+   inline comment it stands for a non-redirected call to the address
+   in $25. */
+#define VALGRIND_CALL_NOREDIR_T9                                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* call-noredir $25 */                         \
+                     "or[32] $t2, $t2, $t2\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                    \
+ do { /* statement form: emits the preamble + $t3 sequence */       \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     "or[32] $t3, $t3, $t3\n\t"                     \
+                    ); /* no operands and no clobbers declared */   \
+ } while (0)
+
+#endif
+/* Insert assembly code for other platforms here... */
+
+#endif /* NVALGRIND */
+
+
+/* ------------------------------------------------------------------ */
+/* PLATFORM SPECIFICS for FUNCTION WRAPPING.  This is all very        */
+/* ugly.  It's the least-worst tradeoff I can think of.               */
+/* ------------------------------------------------------------------ */
+
+/* This section defines magic (a.k.a appalling-hack) macros for doing
+   guaranteed-no-redirection calls, so as to get from function
+   wrappers to the functions they are wrapping.  The whole point is to
+   construct standard call sequences, but to do the call itself with a
+   special no-redirect call pseudo-instruction that the JIT
+   understands and handles specially.  This section is long and
+   repetitious, and I can't see a way to make it shorter.
+
+   The naming scheme is as follows:
+
+      CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc}
+
+   'W' stands for "word" and 'v' for "void".  Hence there are
+   different macros for calling arity 0, 1, 2, 3, 4, etc, functions,
+   and for each, the possibility of returning a word-typed result, or
+   no result.
+*/
+
+/* Use these to write the name of your wrapper.  NOTE: duplicates
+   VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h.  NOTE also: inserts
+   the default behaviour equivalence class tag "0000" into the name.
+   See pub_tool_redir.h for details -- normally you don't need to
+   think about this, though. */
+
+/* Use an extra level of macroisation so as to ensure the soname/fnname
+   args are fully macro-expanded before pasting them together. */
+#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd  /* token-pastes the four arguments into one identifier */
+
+#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname)                    \
+   VG_CONCAT4(_vgw00000ZU_,soname,_,fnname) /* pastes to _vgw00000ZU_<soname>_<fnname> */
+
+#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname)                    \
+   VG_CONCAT4(_vgw00000ZZ_,soname,_,fnname) /* pastes to _vgw00000ZZ_<soname>_<fnname> */
+
+/* Use this macro from within a wrapper function to collect the
+   context (address and possibly other info) of the original function.
+   Once you have that you can then use it in one of the CALL_FN_
+   macros.  The type of the argument _lval is OrigFn. */
+#define VALGRIND_GET_ORIG_FN(_lval)  VALGRIND_GET_NR_CONTEXT(_lval)  /* plain alias: forwards _lval unchanged */
+
+/* Also provide end-user facilities for function replacement, rather
+   than wrapping.  A replacement function differs from a wrapper in
+   that it has no way to get hold of the original function being
+   called, and hence no way to call onwards to it.  In a replacement
+   function, VALGRIND_GET_ORIG_FN always returns zero. */
+
+#define I_REPLACE_SONAME_FNNAME_ZU(soname,fnname)                 \
+   VG_CONCAT4(_vgr00000ZU_,soname,_,fnname) /* pastes to _vgr00000ZU_<soname>_<fnname> */
+
+#define I_REPLACE_SONAME_FNNAME_ZZ(soname,fnname)                 \
+   VG_CONCAT4(_vgr00000ZZ_,soname,_,fnname) /* pastes to _vgr00000ZZ_<soname>_<fnname> */
+
+/* Derivatives of the main macros below, for calling functions
+   returning void. */
+
+#define CALL_FN_v_v(fnptr)                                        \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_v(_junk,fnptr); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_W(fnptr, arg1)                                  \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_W(_junk,fnptr,arg1); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_WW(fnptr, arg1,arg2)                            \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3)                      \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4)                \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5)             \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6)        \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0) /* void form: delegates to the word-returning macro */
+
+#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7)   \
+   do { volatile unsigned long _junk; /* result is dropped */     \
+        CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0) /* void form: delegates to the word-returning macro */
+
+/* ----------------- x86-{linux,darwin,solaris} ---------------- */
+
+#if defined(PLAT_x86_linux)  ||  defined(PLAT_x86_darwin) \
+    ||  defined(PLAT_x86_solaris)
+
+/* These regs are trashed by the hidden call.  No need to mention eax
+   as gcc can already see that; listing it also causes gcc to bomb. */
+#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx"  /* spliced into the asm trash lists of the CALL_FN_W_ macros below */
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK /* asm text */ \
+      "movl %%esp,%%edi\n\t" /* keep old %esp */ \
+      "andl $0xfffffff0,%%esp\n\t" /* clear the low 4 bits: %esp becomes a multiple of 16 */
+#define VALGRIND_RESTORE_STACK /* asm text */ \
+      "movl %%edi,%%esp\n\t" /* put back the %esp that VALGRIND_ALIGN_STACK saved in %edi */
+
+/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned
+   long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* lval = orig(), not redirected */                       \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1]; /* target only */        \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* lval = orig(arg1), not redirected */                   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $12, %%esp\n\t" /* pad: 12 + 4*1 = 16 bytes */     \
+         "pushl 4(%%eax)\n\t" /* push arg1 */                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* lval = orig(arg1,arg2), not redirected */              \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $8, %%esp\n\t" /* pad: 8 + 4*2 = 16 bytes */       \
+         "pushl 8(%%eax)\n\t" /* args pushed last-first */        \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* lval = orig(arg1..arg3), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $4, %%esp\n\t" /* pad: 4 + 4*3 = 16 bytes */       \
+         "pushl 12(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* lval = orig(arg1..arg4), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "pushl 16(%%eax)\n\t" /* 4*4 = 16 bytes: no pad */       \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t" /* arg1 goes on last */             \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* lval = orig(arg1..arg5), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $12, %%esp\n\t" /* pad: 12 + 4*5 = 32 bytes */     \
+         "pushl 20(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* lval = orig(arg1..arg6), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $8, %%esp\n\t" /* pad: 8 + 4*6 = 32 bytes */       \
+         "pushl 24(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* lval = orig(arg1..arg7), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $4, %%esp\n\t" /* pad: 4 + 4*7 = 32 bytes */       \
+         "pushl 28(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* lval = orig(arg1..arg8), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9]; /* target, then args */  \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "pushl 32(%%eax)\n\t" /* 4*8 = 32 bytes: no pad */       \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t" /* arg1 goes on last */             \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* lval = orig(arg1..arg9), not redirected */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10]; /* target, then args */ \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $12, %%esp\n\t" /* pad: 12 + 4*9 = 48 bytes */     \
+         "pushl 36(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* lval = orig(arg1..arg10), not redirected */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11]; /* target, then args */ \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $8, %%esp\n\t" /* pad: 8 + 4*10 = 48 bytes */      \
+         "pushl 40(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do { /* lval = orig(arg1..arg11), not redirected */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12]; /* target, then args */ \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "subl $4, %%esp\n\t" /* pad: 4 + 4*11 = 48 bytes */      \
+         "pushl 44(%%eax)\n\t" /* args pushed last-first */       \
+         "pushl 40(%%eax)\n\t"                                    \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do { /* lval = orig(arg1..arg12), not redirected */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13]; /* target, then args */ \
+      volatile unsigned long _res; /* word returned in %eax */    \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK /* %edi = %esp; align %esp */       \
+         "pushl 48(%%eax)\n\t" /* 4*12 = 48 bytes: no pad */      \
+         "pushl 44(%%eax)\n\t"                                    \
+         "pushl 40(%%eax)\n\t"                                    \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t" /* arg1 goes on last */             \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX /* the hidden call */          \
+         VALGRIND_RESTORE_STACK /* %esp = %edi */                 \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res; /* cast to lval's type */   \
+   } while (0)
+
+#endif /* PLAT_x86_linux || PLAT_x86_darwin || PLAT_x86_solaris */
+
+/* ---------------- amd64-{linux,darwin,solaris} --------------- */
+
+#if defined(PLAT_amd64_linux)  ||  defined(PLAT_amd64_darwin) \
+    ||  defined(PLAT_amd64_solaris)
+
+/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */
+
+/* These regs are trashed by the hidden call. */
+/* Clobber-list fragment shared by the CALL_FN_ macros below.  "rax" is
+   left commented out because those macros already bind %rax as both
+   the input ("a") and the output ("=a") operand of the asm. */
+#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi",       \
+                            "rdi", "r8", "r9", "r10", "r11"
+
+/* This is all pretty complex.  It's so as to make stack unwinding
+   work reliably.  See bug 243270.  The basic problem is the sub and
+   add of 128 of %rsp in all of the following macros.  If gcc believes
+   the CFA is in %rsp, then unwinding may fail, because what's at the
+   CFA is not what gcc "expected" when it constructs the CFIs for the
+   places where the macros are instantiated.
+
+   But we can't just add a CFI annotation to increase the CFA offset
+   by 128, to match the sub of 128 from %rsp, because we don't know
+   whether gcc has chosen %rsp as the CFA at that point, or whether it
+   has chosen some other register (eg, %rbp).  In the latter case,
+   adding a CFI annotation to change the CFA offset is simply wrong.
+
+   So the solution is to get hold of the CFA using
+   __builtin_dwarf_cfa(), put it in a known register, and add a
+   CFI annotation to say what the register is.  We choose %rbp for
+   this (perhaps perversely), because:
+
+   (1) %rbp is already subject to unwinding.  If a new register was
+       chosen then the unwinder would have to unwind it in all stack
+       traces, which is expensive, and
+
+   (2) %rbp is already subject to precise exception updates in the
+       JIT.  If a new register was chosen, we'd have to have precise
+       exceptions for it too, which reduces performance of the
+       generated code.
+
+   However .. one extra complication.  We can't just whack the result
+   of __builtin_dwarf_cfa() into %rbp and then add %rbp to the
+   list of trashed registers at the end of the inline assembly
+   fragments; gcc won't allow %rbp to appear in that list.  Hence
+   instead we need to stash %rbp in %r15 for the duration of the asm,
+   and say that %r15 is trashed instead.  gcc seems happy to go with
+   that.
+
+   Oh .. and this all needs to be conditionalised so that it is
+   unchanged from before this commit, when compiled with older gccs
+   that don't support __builtin_dwarf_cfa.  Furthermore, since
+   this header file is freestanding, it has to be independent of
+   config.h, and so the following conditionalisation cannot depend on
+   configure time checks.
+
+   Although it's not clear from
+   'defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)',
+   this expression excludes Darwin.
+   .cfi directives in Darwin assembly appear to be completely
+   different and I haven't investigated how they work.
+
+   For even more entertainment value, note we have to use the
+   completely undocumented __builtin_dwarf_cfa(), which appears to
+   really compute the CFA, whereas __builtin_frame_address(0) claims
+   to but actually doesn't.  See
+   https://bugs.kde.org/show_bug.cgi?id=243270#c47
+*/
+#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)
+/* Extra asm input operand carrying the CFA.  The CALL_FN_ macros append
+   it after their single "a" input, so with one output operand it is
+   operand number 2, which is why the prologue refers to it as %2. */
+#  define __FRAME_POINTER                                         \
+      ,"r"(__builtin_dwarf_cfa())
+/* Stash the caller's %rbp in %r15, load the CFA into %rbp, remember
+   the current CFI state and declare the CFA to be %rbp + 0.  Users
+   must list "r15" as clobbered. */
+#  define VALGRIND_CFI_PROLOGUE                                   \
+      "movq %%rbp, %%r15\n\t"                                     \
+      "movq %2, %%rbp\n\t"                                        \
+      ".cfi_remember_state\n\t"                                   \
+      ".cfi_def_cfa rbp, 0\n\t"
+/* Put %rbp back from %r15 and return to the CFI state remembered by
+   the prologue. */
+#  define VALGRIND_CFI_EPILOGUE                                   \
+      "movq %%r15, %%rbp\n\t"                                     \
+      ".cfi_restore_state\n\t"
+#else
+/* No CFI-asm support detected: all three macros expand to nothing, so
+   the CALL_FN_ asm blocks contain no CFI fix-up at all. */
+#  define __FRAME_POINTER
+#  define VALGRIND_CFI_PROLOGUE
+#  define VALGRIND_CFI_EPILOGUE
+#endif
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+/* Save the incoming %rsp in %r14, then round %rsp down to a 16-byte
+   boundary (mask 0x...f0).  Users must list "r14" as clobbered. */
+#define VALGRIND_ALIGN_STACK               \
+      "movq %%rsp,%%r14\n\t"               \
+      "andq $0xfffffffffffffff0,%%rsp\n\t"
+/* Reload %rsp from %r14.  This undoes the alignment above together
+   with any subq/pushq adjustments made after it. */
+#define VALGRIND_RESTORE_STACK             \
+      "movq %%r14,%%rsp\n\t"
+
+/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned
+   long) == 8. */
+
+/* NB 9 Sept 07.  There is a nasty kludge here in all these CALL_FN_
+   macros.  In order not to trash the stack redzone, we need to drop
+   %rsp by 128 before the hidden call, and restore afterwards.  The
+   nastiness is that it is only by luck that the stack still appears
+   to be unwindable during the hidden call - since then the behaviour
+   of any routine using this macro does not match what the CFI data
+   says.  Sigh.
+
+   Why is this important?  Imagine that a wrapper has a stack
+   allocated local, and passes to the hidden call, a pointer to it.
+   Because gcc does not know about the hidden call, it may allocate
+   that local in the redzone.  Unfortunately the hidden call may then
+   trash it before it comes to use it.  So we must step clear of the
+   redzone, for the duration of the hidden call, to make it safe.
+
+   Probably the same problem afflicts the other redzone-style ABIs too
+   (ppc64-linux); but for those, the stack is
+   self describing (none of this CFI nonsense) so at least messing
+   with the stack pointer doesn't give a danger of non-unwindable
+   stack. */
+
+/* Call the function whose address is (orig).nraddr with no arguments
+   and assign the word left in %rax to lval.
+
+   %rax enters the asm pointing at _argvec; slot 0 (the target
+   address) is loaded into %rax immediately before
+   VALGRIND_CALL_NOREDIR_RAX.  %rsp is aligned (saved in %r14) and then
+   dropped by 128 to step clear of the red zone; VALGRIND_RESTORE_STACK
+   undoes both.  %r14 and %r15 are therefore listed as clobbered. */
+#define CALL_FN_W_v(lval, orig)                                        \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[1];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with one word
+   argument and assign the word left in %rax to lval.
+
+   The argument is cast to unsigned long, stored in _argvec[1] and
+   loaded into %rdi.  %rax is used as the base pointer to _argvec, so
+   the target address (slot 0) is loaded into it last.  %rsp is
+   aligned and dropped by 128 to step clear of the red zone. */
+#define CALL_FN_W_W(lval, orig, arg1)                                  \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[2];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with two word
+   arguments and assign the word left in %rax to lval.
+
+   arg1 and arg2 are loaded from _argvec[1..2] into %rdi and %rsi.
+   %rax is used as the base pointer to _argvec, so the target address
+   (slot 0) is loaded into it last.  %rsp is aligned and dropped by 128
+   to step clear of the red zone. */
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                            \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[3];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with three word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg3 are loaded from _argvec[1..3] into %rdi, %rsi and %rdx.
+   %rax is used as the base pointer to _argvec, so the target address
+   (slot 0) is loaded into it last.  %rsp is aligned and dropped by 128
+   to step clear of the red zone. */
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                      \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[4];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with four word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg4 are loaded from _argvec[1..4] into %rdi, %rsi, %rdx and
+   %rcx.  %rax is used as the base pointer to _argvec, so the target
+   address (slot 0) is loaded into it last.  %rsp is aligned and
+   dropped by 128 to step clear of the red zone. */
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)                \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[5];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with five word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg5 are loaded from _argvec[1..5] into %rdi, %rsi, %rdx,
+   %rcx and %r8.  %rax is used as the base pointer to _argvec, so the
+   target address (slot 0) is loaded into it last.  %rsp is aligned
+   and dropped by 128 to step clear of the red zone. */
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)             \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[6];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with six word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg6 are loaded from _argvec[1..6] into %rdi, %rsi, %rdx,
+   %rcx, %r8 and %r9, i.e. this is the largest arity that needs no
+   stack-passed arguments.  %rax is used as the base pointer to
+   _argvec, so the target address (slot 0) is loaded into it last.
+   %rsp is aligned and dropped by 128 to step clear of the red zone. */
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)        \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[7];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with seven word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg6 go in %rdi, %rsi, %rdx, %rcx, %r8 and %r9; arg7 is pushed
+   on the stack.  %rsp is dropped by 136 rather than 128 here: with the
+   odd number (1) of 8-byte pushes that follows, the total adjustment
+   is 136 + 8 = 144, a multiple of 16, so the alignment established by
+   VALGRIND_ALIGN_STACK still holds at the call.  The pushed slot is
+   discarded when VALGRIND_RESTORE_STACK reloads %rsp from %r14. */
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,        \
+                                 arg7)                                 \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[8];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $136,%%rsp\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with eight word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg6 go in %rdi, %rsi, %rdx, %rcx, %r8 and %r9; arg8 then arg7
+   are pushed, so arg7 ends up at the lowest address.  With an even
+   number (2) of 8-byte pushes the plain 128-byte drop is used:
+   128 + 16 = 144, a multiple of 16, so the alignment established by
+   VALGRIND_ALIGN_STACK still holds at the call.  The pushed slots are
+   discarded when VALGRIND_RESTORE_STACK reloads %rsp from %r14. */
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,        \
+                                 arg7,arg8)                            \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[9];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with nine word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg6 go in %rdi, %rsi, %rdx, %rcx, %r8 and %r9; arg9, arg8 and
+   arg7 are pushed in that order, so arg7 ends up at the lowest
+   address.  With an odd number (3) of 8-byte pushes %rsp is dropped by
+   136 rather than 128: 136 + 24 = 160, a multiple of 16, so the
+   alignment established by VALGRIND_ALIGN_STACK still holds at the
+   call.  The pushed slots are discarded when VALGRIND_RESTORE_STACK
+   reloads %rsp from %r14. */
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,        \
+                                 arg7,arg8,arg9)                       \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[10];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $136,%%rsp\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with ten word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg6 go in %rdi, %rsi, %rdx, %rcx, %r8 and %r9; arg10 down to
+   arg7 are pushed in that order, so arg7 ends up at the lowest
+   address.  With an even number (4) of 8-byte pushes the plain
+   128-byte drop is used: 128 + 32 = 160, a multiple of 16, so the
+   alignment established by VALGRIND_ALIGN_STACK still holds at the
+   call.  The pushed slots are discarded when VALGRIND_RESTORE_STACK
+   reloads %rsp from %r14. */
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,       \
+                                  arg7,arg8,arg9,arg10)                \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[11];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      _argvec[10] = (unsigned long)(arg10);                            \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "pushq 80(%%rax)\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* Call the function whose address is (orig).nraddr with eleven word
+   arguments and assign the word left in %rax to lval.
+
+   arg1..arg6 go in %rdi, %rsi, %rdx, %rcx, %r8 and %r9; arg11 down to
+   arg7 are pushed in that order, so arg7 ends up at the lowest
+   address.  With an odd number (5) of 8-byte pushes %rsp is dropped by
+   136 rather than 128: 136 + 40 = 176, a multiple of 16, so the
+   alignment established by VALGRIND_ALIGN_STACK still holds at the
+   call.  The pushed slots are discarded when VALGRIND_RESTORE_STACK
+   reloads %rsp from %r14. */
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,       \
+                                  arg7,arg8,arg9,arg10,arg11)          \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[12];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      _argvec[10] = (unsigned long)(arg10);                            \
+      _argvec[11] = (unsigned long)(arg11);                            \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $136,%%rsp\n\t"                                         \
+         "pushq 88(%%rax)\n\t"                                         \
+         "pushq 80(%%rax)\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+/* CALL_FN_W_12W: call the function at orig.nraddr with twelve
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..12] the
+   arguments, each cast to unsigned long.  The asm gets &_argvec[0]
+   in %rax, pushes args 12..7 on the stack, loads args 6..1 into
+   %r9, %r8, %rcx, %rdx, %rsi, %rdi and only then overwrites %rax
+   with the target, since %rax is the base for every earlier load.
+   The result is taken from %rax.
+
+   NOTE(review): $128 plus the six 8-byte pushes moves %rsp by 176,
+   a multiple of 16; presumably this preserves the alignment set up
+   by VALGRIND_ALIGN_STACK while stepping over the 128-byte red
+   zone -- confirm against the notes earlier in this header. */
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,       \
+                                arg7,arg8,arg9,arg10,arg11,arg12)      \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[13];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      _argvec[10] = (unsigned long)(arg10);                            \
+      _argvec[11] = (unsigned long)(arg11);                            \
+      _argvec[12] = (unsigned long)(arg12);                            \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "pushq 96(%%rax)\n\t"  /* arg12 */                            \
+         "pushq 88(%%rax)\n\t"  /* arg11 */                            \
+         "pushq 80(%%rax)\n\t"  /* arg10 */                            \
+         "pushq 72(%%rax)\n\t"  /* arg9 */                             \
+         "pushq 64(%%rax)\n\t"  /* arg8 */                             \
+         "pushq 56(%%rax)\n\t"  /* arg7 */                             \
+         "movq 48(%%rax), %%r9\n\t"  /* arg6 */                        \
+         "movq 40(%%rax), %%r8\n\t"  /* arg5 */                        \
+         "movq 32(%%rax), %%rcx\n\t"  /* arg4 */                       \
+         "movq 24(%%rax), %%rdx\n\t"  /* arg3 */                       \
+         "movq 16(%%rax), %%rsi\n\t"  /* arg2 */                       \
+         "movq 8(%%rax), %%rdi\n\t"  /* arg1 */                        \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */
+
+/* ------------------------ ppc32-linux ------------------------ */
+
+#if defined(PLAT_ppc32_linux)
+
+/* This is useful for finding out about the on-stack stuff:
+
+   extern int f9  ( int,int,int,int,int,int,int,int,int );
+   extern int f10 ( int,int,int,int,int,int,int,int,int,int );
+   extern int f11 ( int,int,int,int,int,int,int,int,int,int,int );
+   extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int );
+
+   int g9 ( void ) {
+      return f9(11,22,33,44,55,66,77,88,99);
+   }
+   int g10 ( void ) {
+      return f10(11,22,33,44,55,66,77,88,99,110);
+   }
+   int g11 ( void ) {
+      return f11(11,22,33,44,55,66,77,88,99,110,121);
+   }
+   int g12 ( void ) {
+      return f12(11,22,33,44,55,66,77,88,99,110,121,132);
+   }
+*/
+
+/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */
+
+/* These regs are trashed by the hidden call.  The list is spliced
+   into the asm "trash" (clobber) list of each ppc32 CALL_FN_ macro
+   and names lr, ctr, xer, the eight condition-register fields
+   cr0-cr7, and the general registers r0 and r2-r13.  The macros add
+   "r28" themselves, because VALGRIND_ALIGN_STACK parks the original
+   stack pointer there. */
+#define __CALLER_SAVED_REGS                                       \
+   "lr", "ctr", "xer",                                            \
+   "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7",        \
+   "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10",   \
+   "r11", "r12", "r13"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions.
+
+   VALGRIND_ALIGN_STACK copies r1 into r28, then clears the low four
+   bits of r1 (rlwinm with rotate 0 and mask bits 0..27 keeps only
+   the upper 28 bits), i.e. rounds the stack pointer down to a
+   16-byte boundary.  VALGRIND_RESTORE_STACK copies r28 back into r1,
+   which also discards any further adjustment made to r1 in between.
+   Because r28 is overwritten, every user must list "r28" as
+   clobbered. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "mr 28,1\n\t"                        \
+      "rlwinm 1,1,0,0,27\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "mr 1,28\n\t"
+
+/* These CALL_FN_ macros assume that on ppc32-linux, 
+   sizeof(unsigned long) == 4. */
+
+/* CALL_FN_W_v: call the function at orig.nraddr with no arguments
+   and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address.  The asm copies &_argvec[0]
+   into r11, then reuses r11 for the target address loaded from it.
+   The stack pointer is saved/aligned around the call (hence the
+   extra "r28" clobber) and the result is copied out of r3. */
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_W: call the function at orig.nraddr with one word-sized
+   argument and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1] the argument.
+   The asm copies &_argvec[0] into r11, loads the argument into r3
+   and then reuses r11 for the target address.  The result is copied
+   out of r3.  The argument is parenthesised before the cast so that
+   an expression such as p + 1 is converted as a whole. */
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_WW: call the function at orig.nraddr with two word-sized
+   arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..2] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3 and r4 and then reuses r11 for the target
+   address.  The result is copied out of r3.  Each argument is
+   parenthesised before the cast so that an expression such as p + 1
+   is converted as a whole. */
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_WWW: call the function at orig.nraddr with three
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..3] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3..r5 and then reuses r11 for the target address.
+   The result is copied out of r3.  Each argument is parenthesised
+   before the cast so that an expression such as p + 1 is converted
+   as a whole. */
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_WWWW: call the function at orig.nraddr with four
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..4] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3..r6 and then reuses r11 for the target address.
+   The result is copied out of r3.  Each argument is parenthesised
+   before the cast so that an expression such as p + 1 is converted
+   as a whole. */
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_5W: call the function at orig.nraddr with five
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..5] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3..r7 and then reuses r11 for the target address.
+   The result is copied out of r3.  Each argument is parenthesised
+   before the cast so that an expression such as p + 1 is converted
+   as a whole. */
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_6W: call the function at orig.nraddr with six
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..6] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3..r8 and then reuses r11 for the target address.
+   The result is copied out of r3.  Each argument is parenthesised
+   before the cast so that an expression such as p + 1 is converted
+   as a whole. */
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 8,24(11)\n\t"  /* arg6->r8 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_7W: call the function at orig.nraddr with seven
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..7] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3..r9 and then reuses r11 for the target address.
+   The result is copied out of r3.  Each argument is parenthesised
+   before the cast so that an expression such as p + 1 is converted
+   as a whole. */
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 8,24(11)\n\t"  /* arg6->r8 */                       \
+         "lwz 9,28(11)\n\t"  /* arg7->r9 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_8W: call the function at orig.nraddr with eight
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..8] the
+   arguments.  The asm copies &_argvec[0] into r11, loads the
+   arguments into r3..r10 and then reuses r11 for the target
+   address.  The result is copied out of r3.  Each argument is
+   parenthesised before the cast so that an expression such as p + 1
+   is converted as a whole. */
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 8,24(11)\n\t"  /* arg6->r8 */                       \
+         "lwz 9,28(11)\n\t"  /* arg7->r9 */                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_9W: call the function at orig.nraddr with nine
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..9] the
+   arguments.  The asm copies &_argvec[0] into r11, lowers r1 by 16
+   bytes and stores arg9 at 8(r1), using r3 as scratch before r3 is
+   loaded with arg1.  Args 1..8 go in r3..r10, then r11 is reused for
+   the target address.  VALGRIND_RESTORE_STACK puts r1 back, and the
+   result is copied out of r3.  Each argument is parenthesised before
+   the cast so that an expression such as p + 1 is converted as a
+   whole. */
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "addi 1,1,-16\n\t"  /* r1 -= 16 */                       \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 8,24(11)\n\t"  /* arg6->r8 */                       \
+         "lwz 9,28(11)\n\t"  /* arg7->r9 */                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_10W: call the function at orig.nraddr with ten
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..10] the
+   arguments.  The asm copies &_argvec[0] into r11, lowers r1 by 16
+   bytes and stores arg10 at 12(r1) and arg9 at 8(r1), using r3 as
+   scratch before r3 is loaded with arg1.  Args 1..8 go in r3..r10,
+   then r11 is reused for the target address.
+   VALGRIND_RESTORE_STACK puts r1 back, and the result is copied out
+   of r3.  Each argument is parenthesised before the cast so that an
+   expression such as p + 1 is converted as a whole. */
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "addi 1,1,-16\n\t"  /* r1 -= 16 */                       \
+         /* arg10 */                                              \
+         "lwz 3,40(11)\n\t"                                       \
+         "stw 3,12(1)\n\t"                                        \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 8,24(11)\n\t"  /* arg6->r8 */                       \
+         "lwz 9,28(11)\n\t"  /* arg7->r9 */                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+/* CALL_FN_W_11W: call the function at orig.nraddr with eleven
+   word-sized arguments and assign its word-sized result to lval.
+
+   _argvec[0] holds the target address and _argvec[1..11] the
+   arguments.  The asm copies &_argvec[0] into r11, lowers r1 by 32
+   bytes and stores arg11 at 16(r1), arg10 at 12(r1) and arg9 at
+   8(r1), using r3 as scratch before r3 is loaded with arg1.
+   Args 1..8 go in r3..r10, then r11 is reused for the target
+   address.  VALGRIND_RESTORE_STACK puts r1 back, and the result is
+   copied out of r3.  Each argument is parenthesised before the cast
+   so that an expression such as p + 1 is converted as a whole. */
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"      /* &_argvec[0]->r11 */               \
+         "addi 1,1,-32\n\t"  /* r1 -= 32 */                       \
+         /* arg11 */                                              \
+         "lwz 3,44(11)\n\t"                                       \
+         "stw 3,16(1)\n\t"                                        \
+         /* arg10 */                                              \
+         "lwz 3,40(11)\n\t"                                       \
+         "stw 3,12(1)\n\t"                                        \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"   /* arg2->r4 */                       \
+         "lwz 5,12(11)\n\t"  /* arg3->r5 */                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"  /* arg5->r7 */                       \
+         "lwz 8,24(11)\n\t"  /* arg6->r8 */                       \
+         "lwz 9,28(11)\n\t"  /* arg7->r9 */                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"           /* r3->_res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                arg7,arg8,arg9,arg10,arg11,arg12) \
+   do { /* call orig.nraddr(arg1..arg12); word result -> lval */  \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13]; /* target + 12 args */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      _argvec[8] = (unsigned long)arg8;                           \
+      _argvec[9] = (unsigned long)arg9;                           \
+      _argvec[10] = (unsigned long)arg10;                         \
+      _argvec[11] = (unsigned long)arg11;                         \
+      _argvec[12] = (unsigned long)arg12;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[0] */                  \
+         "addi 1,1,-32\n\t"  /* expand stack frame */             \
+         /* arg12: via r3 to 20(r1) */                            \
+         "lwz 3,48(11)\n\t"                                       \
+         "stw 3,20(1)\n\t"                                        \
+         /* arg11: via r3 to 16(r1) */                            \
+         "lwz 3,44(11)\n\t"                                       \
+         "stw 3,16(1)\n\t"                                        \
+         /* arg10: via r3 to 12(r1) */                            \
+         "lwz 3,40(11)\n\t"                                       \
+         "stw 3,12(1)\n\t"                                        \
+         /* arg9: via r3 to 8(r1) */                              \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 (loaded last: r3 was scratch above) */        \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"       /* r3 -> _res */                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_ppc32_linux */
+
+/* ------------------------ ppc64-linux ------------------------ */
+
+#if defined(PLAT_ppc64be_linux)
+
+/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */
+
+/* These regs are trashed by the hidden call (asm clobber list). */
+#define __CALLER_SAVED_REGS                                       \
+   "lr", "ctr", "xer",                                            \
+   "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7",        \
+   "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10",         \
+   "r11", "r12", "r13"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "mr 28,1\n\t"   /* r28 = old r1 */   \
+      "rldicr 1,1,0,59\n\t"  /* clear low 4 bits of r1 (16-byte align) */
+#define VALGRIND_RESTORE_STACK             \
+      "mr 1,28\n\t"  /* r1 = value saved in r28 */
+
+/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned
+   long) == 8. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* call orig.nraddr(); word result -> lval */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+0];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1] = (unsigned long)_orig.r2; /* target tocptr */   \
+      _argvec[2] = (unsigned long)_orig.nraddr; /* target */      \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* call orig.nraddr(arg1); word result -> lval */         \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+1];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* call orig.nraddr(arg1,arg2); word result -> lval */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+2];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* call orig.nraddr(arg1..arg3); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+3];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* call orig.nraddr(arg1..arg4); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+4];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* call orig.nraddr(arg1..arg5); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+5];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* call orig.nraddr(arg1..arg6); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+6];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* call orig.nraddr(arg1..arg7); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+7];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* call orig.nraddr(arg1..arg8); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+8];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* call orig.nraddr(arg1..arg9); word result -> lval */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+9];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg9: via r3 to 112(r1) */                            \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 (loaded last: r3 was scratch above) */        \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* call orig.nraddr(arg1..arg10); word result -> lval */  \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+10];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2; /* target tocptr */ \
+      _argvec[2]   = (unsigned long)_orig.nraddr; /* target */    \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] */                  \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg10: via r3 to 120(r1) */                           \
+         "ld  3,80(11)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg9: via r3 to 112(r1) */                            \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 (loaded last: r3 was scratch above) */        \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"  /* r11 = &_argvec[2] again */            \
+         "mr %0,3\n\t"   /* r3 -> _res */                         \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do { /* call _orig.nraddr, 11 word args; r3 -> lval */         \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+11];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"       /* r11 = &_argvec[2] */             \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame */            \
+         /* arg11: _argvec[2+11] -> 128(r1) */                    \
+         "ld  3,88(11)\n\t"                                       \
+         "std 3,128(1)\n\t"                                       \
+         /* arg10: _argvec[2+10] -> 120(r1) */                    \
+         "ld  3,80(11)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg9: _argvec[2+9] -> 112(r1) */                      \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 */                                            \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                arg7,arg8,arg9,arg10,arg11,arg12) \
+   do { /* call _orig.nraddr, 12 word args; r3 -> lval */         \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+12];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      _argvec[2+12] = (unsigned long)arg12;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"       /* r11 = &_argvec[2] */             \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame */            \
+         /* arg12: _argvec[2+12] -> 136(r1) */                    \
+         "ld  3,96(11)\n\t"                                       \
+         "std 3,136(1)\n\t"                                       \
+         /* arg11: _argvec[2+11] -> 128(r1) */                    \
+         "ld  3,88(11)\n\t"                                       \
+         "std 3,128(1)\n\t"                                       \
+         /* arg10: _argvec[2+10] -> 120(r1) */                    \
+         "ld  3,80(11)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg9: _argvec[2+9] -> 112(r1) */                      \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 */                                            \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_ppc64be_linux */
+
+/* ------------------------- ppc64le-linux ----------------------- */
+#if defined(PLAT_ppc64le_linux)
+
+/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS   /* asm clobber-list fragment */     \
+   "lr", "ctr", "xer",  /* special-purpose regs */                \
+   "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7",        \
+   "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10",         \
+   "r11", "r12", "r13"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "mr 28,1\n\t"  /* save r1 in r28 */  \
+      "rldicr 1,1,0,59\n\t"  /* clear low 4 bits of r1 (16-byte align) */
+#define VALGRIND_RESTORE_STACK             \
+      "mr 1,28\n\t"  /* r1 = r28 (value saved by VALGRIND_ALIGN_STACK) */
+
+/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned
+   long) == 8. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* call _orig.nraddr, 0 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+0];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1] = (unsigned long)_orig.r2;                       \
+      _argvec[2] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* call _orig.nraddr, 1 word arg; r3 -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+1];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* call _orig.nraddr, 2 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+2];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* call _orig.nraddr, 3 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+3];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* call _orig.nraddr, 4 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+4];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* call _orig.nraddr, 5 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+5];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* call _orig.nraddr, 6 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+6];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* call _orig.nraddr, 7 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+7];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* call _orig.nraddr, 8 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+8];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* call _orig.nraddr, 9 word args; r3 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+9];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg9: _argvec[2+9] -> 96(r1) */                       \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 */                                            \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* reload &_argvec[2] */            \
+         "mr %0,3\n\t"        /* r3 -> _res */                    \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* ppc64le: invoke orig.nraddr with 10 word args */       \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+10];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr in _argvec[0] */     \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame by 128 */     \
+         /* arg10: _argvec[2+10] -> 104(r1) */                    \
+         "ld  3,80(12)\n\t"                                       \
+         "std 3,104(1)\n\t"                                       \
+         /* arg9: _argvec[2+9] -> 96(r1) */                       \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 -> r3..r10 */                                 \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] again */       \
+         "mr %0,3\n\t"        /* _res = r3 */                     \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do { /* ppc64le: invoke orig.nraddr with 11 word args */       \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+11];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr in _argvec[0] */     \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame by 144 */     \
+         /* arg11: _argvec[2+11] -> 112(r1) */                    \
+         "ld  3,88(12)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* arg10: _argvec[2+10] -> 104(r1) */                    \
+         "ld  3,80(12)\n\t"                                       \
+         "std 3,104(1)\n\t"                                       \
+         /* arg9: _argvec[2+9] -> 96(r1) */                       \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 -> r3..r10 */                                 \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] again */       \
+         "mr %0,3\n\t"        /* _res = r3 */                     \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                arg7,arg8,arg9,arg10,arg11,arg12) \
+   do { /* ppc64le: invoke orig.nraddr with 12 word args */       \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+12];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      _argvec[2+12] = (unsigned long)arg12;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] */             \
+         "std 2,-16(12)\n\t"  /* save tocptr in _argvec[0] */     \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame by 144 */     \
+         /* arg12: _argvec[2+12] -> 120(r1) */                    \
+         "ld  3,96(12)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg11: _argvec[2+11] -> 112(r1) */                    \
+         "ld  3,88(12)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* arg10: _argvec[2+10] -> 104(r1) */                    \
+         "ld  3,80(12)\n\t"                                       \
+         "std 3,104(1)\n\t"                                       \
+         /* arg9: _argvec[2+9] -> 96(r1) */                       \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 -> r3..r10 */                                 \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"       /* r12 = &_argvec[2] again */       \
+         "mr %0,3\n\t"        /* _res = r3 */                     \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_ppc64le_linux */
+
+/* ------------------------- arm-linux ------------------------- */
+
+#if defined(PLAT_arm_linux)
+
+/* These regs are trashed by the hidden call (listed as asm clobbers). */
+#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4", "r12", "r14"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+/* This is a bit tricky.  We store the original stack pointer in r10
+   as it is callee-saved.  gcc doesn't allow the use of r11 for some
+   reason.  Also, we can't directly "bic" the stack pointer in thumb
+   mode since r13 isn't an allowed register number in that context.
+   So use r4 as a temporary, since that is about to get trashed
+   anyway, just after each use of this macro.  Side effect is we need
+   to be very careful about any future changes, since
+   VALGRIND_ALIGN_STACK simply assumes r4 is usable. */
+#define VALGRIND_ALIGN_STACK               \
+      "mov r10, sp\n\t" /* r10 = old sp */ \
+      "mov r4,  sp\n\t" /* r4 = sp */      \
+      "bic r4,  r4, #7\n\t" /* 8-align */  \
+      "mov sp,  r4\n\t" /* sp = aligned value */
+#define VALGRIND_RESTORE_STACK             \
+      "mov sp,  r10\n\t" /* sp = value saved by VALGRIND_ALIGN_STACK */
+
+/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned
+   long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* arm: invoke orig.nraddr with no args */                \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"       /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* arm: invoke orig.nraddr with 1 word arg */             \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"       /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* arm: invoke orig.nraddr with 2 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"       /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* arm: invoke orig.nraddr with 3 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"       /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* arm: invoke orig.nraddr with 4 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r3, [%1, #16] \n\t" /* arg4->r3 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"         /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* arm: invoke orig.nraddr with 5 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"    /* pad: keep sp 8-aligned */    \
+         "ldr r0, [%1, #20] \n\t" /* arg5 */                      \
+         "push {r0} \n\t"         /* arg5 -> stack */             \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r3, [%1, #16] \n\t" /* arg4->r3 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"         /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* arm: invoke orig.nraddr with 6 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #20] \n\t" /* arg5 */                      \
+         "ldr r1, [%1, #24] \n\t" /* arg6 */                      \
+         "push {r0, r1} \n\t"     /* arg5,arg6 -> stack */        \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r3, [%1, #16] \n\t" /* arg4->r3 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"         /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* arm: invoke orig.nraddr with 7 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"    /* pad: keep sp 8-aligned */    \
+         "ldr r0, [%1, #20] \n\t" /* arg5 */                      \
+         "ldr r1, [%1, #24] \n\t" /* arg6 */                      \
+         "ldr r2, [%1, #28] \n\t" /* arg7 */                      \
+         "push {r0, r1, r2} \n\t" /* args5-7 -> stack */          \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r3, [%1, #16] \n\t" /* arg4->r3 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"         /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* arm: invoke orig.nraddr with 8 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #20] \n\t" /* arg5 */                      \
+         "ldr r1, [%1, #24] \n\t" /* arg6 */                      \
+         "ldr r2, [%1, #28] \n\t" /* arg7 */                      \
+         "ldr r3, [%1, #32] \n\t" /* arg8 */                      \
+         "push {r0, r1, r2, r3} \n\t" /* args5-8 -> stack */      \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r3, [%1, #16] \n\t" /* arg4->r3 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"         /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* arm: invoke orig.nraddr with 9 word args */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"    /* pad: keep sp 8-aligned */    \
+         "ldr r0, [%1, #20] \n\t" /* arg5 */                      \
+         "ldr r1, [%1, #24] \n\t" /* arg6 */                      \
+         "ldr r2, [%1, #28] \n\t" /* arg7 */                      \
+         "ldr r3, [%1, #32] \n\t" /* arg8 */                      \
+         "ldr r4, [%1, #36] \n\t" /* arg9 (r4 as a temporary) */  \
+         "push {r0, r1, r2, r3, r4} \n\t" /* args5-9 -> stack */  \
+         "ldr r0, [%1, #4] \n\t" /* arg1->r0 */                   \
+         "ldr r1, [%1, #8] \n\t" /* arg2->r1 */                   \
+         "ldr r2, [%1, #12] \n\t" /* arg3->r2 */                  \
+         "ldr r3, [%1, #16] \n\t" /* arg4->r3 */                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"         /* r0 -> _res */                    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])  /* %1 tied to %0 */       \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      ); /* r10 trashed: VALGRIND_ALIGN_STACK saves sp there */   \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* arm: call orig with 10 word args -> lval */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11]; /* 4B slots assumed */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #40] \n\t"  /* arg10 */                    \
+         "push {r0} \n\t"  /* push arg10 */                       \
+         "ldr r0, [%1, #20] \n\t"  /* arg5 */                     \
+         "ldr r1, [%1, #24] \n\t"  /* arg6 */                     \
+         "ldr r2, [%1, #28] \n\t"  /* arg7 */                     \
+         "ldr r3, [%1, #32] \n\t"  /* arg8 */                     \
+         "ldr r4, [%1, #36] \n\t"  /* arg9 */                     \
+         "push {r0, r1, r2, r3, r4} \n\t"  /* push args 5-9 */    \
+         "ldr r0, [%1, #4] \n\t"  /* arg1 */                      \
+         "ldr r1, [%1, #8] \n\t"  /* arg2 */                      \
+         "ldr r2, [%1, #12] \n\t"  /* arg3 */                     \
+         "ldr r3, [%1, #16] \n\t"  /* arg4 */                     \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do { /* arm: call orig with 11 word args -> lval */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12]; /* 4B slots assumed */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"  /* 4-byte pad (alignment?) */     \
+         "ldr r0, [%1, #40] \n\t"  /* arg10 */                    \
+         "ldr r1, [%1, #44] \n\t"  /* arg11 */                    \
+         "push {r0, r1} \n\t"  /* push args 10-11 */              \
+         "ldr r0, [%1, #20] \n\t"  /* arg5 */                     \
+         "ldr r1, [%1, #24] \n\t"  /* arg6 */                     \
+         "ldr r2, [%1, #28] \n\t"  /* arg7 */                     \
+         "ldr r3, [%1, #32] \n\t"  /* arg8 */                     \
+         "ldr r4, [%1, #36] \n\t"  /* arg9 */                     \
+         "push {r0, r1, r2, r3, r4} \n\t"  /* push args 5-9 */    \
+         "ldr r0, [%1, #4] \n\t"  /* arg1 */                      \
+         "ldr r1, [%1, #8] \n\t"  /* arg2 */                      \
+         "ldr r2, [%1, #12] \n\t"  /* arg3 */                     \
+         "ldr r3, [%1, #16] \n\t"  /* arg4 */                     \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do { /* arm: call orig with 12 word args -> lval */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13]; /* 4B slots assumed */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #40] \n\t"  /* arg10 */                    \
+         "ldr r1, [%1, #44] \n\t"  /* arg11 */                    \
+         "ldr r2, [%1, #48] \n\t"  /* arg12 */                    \
+         "push {r0, r1, r2} \n\t"  /* push args 10-12 */          \
+         "ldr r0, [%1, #20] \n\t"  /* arg5 */                     \
+         "ldr r1, [%1, #24] \n\t"  /* arg6 */                     \
+         "ldr r2, [%1, #28] \n\t"  /* arg7 */                     \
+         "ldr r3, [%1, #32] \n\t"  /* arg8 */                     \
+         "ldr r4, [%1, #36] \n\t"  /* arg9 */                     \
+         "push {r0, r1, r2, r3, r4} \n\t"  /* push args 5-9 */    \
+         "ldr r0, [%1, #4] \n\t"  /* arg1 */                      \
+         "ldr r1, [%1, #8] \n\t"  /* arg2 */                      \
+         "ldr r2, [%1, #12] \n\t"  /* arg3 */                     \
+         "ldr r3, [%1, #16] \n\t"  /* arg4 */                     \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_arm_linux */
+
+/* ------------------------ arm64-linux ------------------------ */
+
+#if defined(PLAT_arm64_linux)
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS /* asm clobber list for CALL_FN_* */ \
+     "x0", "x1", "x2", "x3","x4", "x5", "x6", "x7", "x8", "x9",   \
+     "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17",      \
+     "x18", "x19", "x20", "x30",  /* x30: link register */        \
+     "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",  \
+     "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",      \
+     "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",      \
+     "v26", "v27", "v28", "v29", "v30", "v31"  /* v0-v31: SIMD/FP registers */
+
+/* x21 is callee-saved, so we can use it to save and restore SP around
+   the hidden call. */
+#define VALGRIND_ALIGN_STACK               \
+      "mov x21, sp\n\t"  /* save SP */     \
+      "bic sp, x21, #15\n\t"  /* SP = x21 & ~15, i.e. rounded down to 16 bytes */
+#define VALGRIND_RESTORE_STACK             \
+      "mov sp,  x21\n\t"  /* put back the SP saved in x21 by VALGRIND_ALIGN_STACK */
+
+/* These CALL_FN_ macros assume that on arm64-linux,
+   sizeof(unsigned long) == 8. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* arm64: call orig with no args -> lval */               \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"  /* result -> _res */                     \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* arm64: call orig with 1 word arg -> lval */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"  /* result -> _res */                     \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* arm64: call orig with 2 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"  /* result -> _res */                     \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* arm64: call orig with 3 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"  /* result -> _res */                     \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* arm64: call orig with 4 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* arm64: call orig with 5 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x4, [%1, #40] \n\t"  /* arg5 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* arm64: call orig with 6 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x4, [%1, #40] \n\t"  /* arg5 */                     \
+         "ldr x5, [%1, #48] \n\t"  /* arg6 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* arm64: call orig with 7 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x4, [%1, #40] \n\t"  /* arg5 */                     \
+         "ldr x5, [%1, #48] \n\t"  /* arg6 */                     \
+         "ldr x6, [%1, #56] \n\t"  /* arg7 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* arm64: call orig with 8 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x4, [%1, #40] \n\t"  /* arg5 */                     \
+         "ldr x5, [%1, #48] \n\t"  /* arg6 */                     \
+         "ldr x6, [%1, #56] \n\t"  /* arg7 */                     \
+         "ldr x7, [%1, #64] \n\t"  /* arg8 */                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* arm64: call orig with 9 word args -> lval */           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x20 \n\t"  /* 32 bytes for stack args */  \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x4, [%1, #40] \n\t"  /* arg5 */                     \
+         "ldr x5, [%1, #48] \n\t"  /* arg6 */                     \
+         "ldr x6, [%1, #56] \n\t"  /* arg7 */                     \
+         "ldr x7, [%1, #64] \n\t"  /* arg8 */                     \
+         "ldr x8, [%1, #72] \n\t"  /* arg9 */                     \
+         "str x8, [sp, #0]  \n\t"  /* arg9 -> [sp] */             \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* arm64: call orig with 10 word args -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr; /* call target */ \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x20 \n\t"  /* 32 bytes for stack args */  \
+         "ldr x0, [%1, #8] \n\t"  /* arg1 */                      \
+         "ldr x1, [%1, #16] \n\t"  /* arg2 */                     \
+         "ldr x2, [%1, #24] \n\t"  /* arg3 */                     \
+         "ldr x3, [%1, #32] \n\t"  /* arg4 */                     \
+         "ldr x4, [%1, #40] \n\t"  /* arg5 */                     \
+         "ldr x5, [%1, #48] \n\t"  /* arg6 */                     \
+         "ldr x6, [%1, #56] \n\t"  /* arg7 */                     \
+         "ldr x7, [%1, #64] \n\t"  /* arg8 */                     \
+         "ldr x8, [%1, #72] \n\t"  /* arg9 */                     \
+         "str x8, [sp, #0]  \n\t"  /* arg9 -> [sp] */             \
+         "ldr x8, [%1, #80] \n\t"  /* arg10 */                    \
+         "str x8, [sp, #8]  \n\t"  /* arg10 -> [sp+8] */          \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* result -> _res */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do { /* args 1-8 in x0-x7, args 9-11 on stack; result in x0 */ \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x30 \n\t"  /* reserve 48 stack bytes */   \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1, #72] \n\t"                                 \
+         "str x8, [sp, #0]  \n\t"  /* arg9 -> [sp+0] */           \
+         "ldr x8, [%1, #80] \n\t"                                 \
+         "str x8, [sp, #8]  \n\t"  /* arg10 -> [sp+8] */          \
+         "ldr x8, [%1, #88] \n\t"                                 \
+         "str x8, [sp, #16] \n\t"  /* arg11 -> [sp+16] */         \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* return value -> _res */                 \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11,     \
+                                  arg12)                          \
+   do { /* args 1-8 in x0-x7, args 9-12 on stack; result in x0 */ \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x30 \n\t"  /* reserve 48 stack bytes */   \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1, #72] \n\t"                                 \
+         "str x8, [sp, #0]  \n\t"  /* arg9 -> [sp+0] */           \
+         "ldr x8, [%1, #80] \n\t"                                 \
+         "str x8, [sp, #8]  \n\t"  /* arg10 -> [sp+8] */          \
+         "ldr x8, [%1, #88] \n\t"                                 \
+         "str x8, [sp, #16] \n\t"  /* arg11 -> [sp+16] */         \
+         "ldr x8, [%1, #96] \n\t"                                 \
+         "str x8, [sp, #24] \n\t"  /* arg12 -> [sp+24] */         \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"  /* return value -> _res */                 \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_arm64_linux */
+
+/* ------------------------- s390x-linux ------------------------- */
+
+#if defined(PLAT_s390x_linux)
+
+/* Similar workaround to the amd64 one (see above), but we use r11 as
+   frame pointer and save the old r11 in r7. r11 might be used for
+   argvec, therefore we copy argvec into r1, since r1 is clobbered
+   after the call anyway.  */
+#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)
+#  define __FRAME_POINTER                                         \
+      ,"d"(__builtin_dwarf_cfa())  /* extra asm input, used as %2 */
+#  define VALGRIND_CFI_PROLOGUE                                   \
+      ".cfi_remember_state\n\t"                                   \
+      "lgr 1,%1\n\t" /* copy the argvec pointer in r1 */          \
+      "lgr 7,11\n\t"  /* save old r11 in r7 */                    \
+      "lgr 11,%2\n\t"  /* r11 = CFA passed in as %2 */            \
+      ".cfi_def_cfa r11, 0\n\t"
+#  define VALGRIND_CFI_EPILOGUE                                   \
+      "lgr 11, 7\n\t"  /* restore r11 from r7 */                  \
+      ".cfi_restore_state\n\t"
+#else
+#  define __FRAME_POINTER  /* empty: no CFA input in this configuration */
+#  define VALGRIND_CFI_PROLOGUE                                   \
+      "lgr 1,%1\n\t"  /* copy the argvec pointer in r1 */
+#  define VALGRIND_CFI_EPILOGUE  /* empty: prologue did not change r11 */
+#endif
+
+/* Nb: On s390 the stack pointer is properly aligned *at all times*
+   according to the s390 GCC maintainer. (The ABI specification is not
+   precise in this regard.) Therefore, VALGRIND_ALIGN_STACK and
+   VALGRIND_RESTORE_STACK are not defined here. */
+
+/* These regs are trashed by the hidden call. Note that we overwrite
+   r14 in s390_irgen_noredir (VEX/priv/guest_s390_irgen.c) to give the
+   function a proper return address. All others are ABI defined call
+   clobbers. */
+#if defined(__VX__) || defined(__S390_VX__)  /* either VX macro set: clobber v0-v31 */
+#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14",   \
+      "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",             \
+      "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",       \
+      "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",     \
+      "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+#else  /* otherwise f0-f7 are the only non-GPR clobbers listed */
+#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14",   \
+      "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7"
+#endif
+
+/* Nb: Although r11 is modified in the asm snippets below (inside 
+   VALGRIND_CFI_PROLOGUE) it is not listed in the clobber section, for
+   two reasons:
+   (1) r11 is restored in VALGRIND_CFI_EPILOGUE, so effectively it is not
+       modified
+   (2) GCC will complain that r11 cannot appear inside a clobber section,
+       when compiled with -O -fno-omit-frame-pointer
+ */
+
+#define CALL_FN_W_v(lval, orig)                                  \
+   do { /* no args; result read from r2 */                       \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long  _argvec[1];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"  /* reserve 160 stack bytes */       \
+         "lg 1, 0(1)\n\t"  /* target->r1 */                      \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "d" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+/* The call abi has the arguments in r2-r6 and stack */
+#define CALL_FN_W_W(lval, orig, arg1)                            \
+   do { /* arg1 in r2; result read from r2 */                    \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[2];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"  /* reserve 160 stack bytes */       \
+         "lg 2, 8(1)\n\t"  /* arg1 -> r2 */                      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1, arg2)                     \
+   do { /* arg1-2 in r2-r3; result from r2 */                    \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[3];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"  /* reserve 160 stack bytes */       \
+         "lg 2, 8(1)\n\t"  /* arg1 -> r2 */                      \
+         "lg 3,16(1)\n\t"  /* arg2 -> r3 */                      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1, arg2, arg3)              \
+   do { /* arg1-3 in r2-r4; result from r2 */                    \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[4];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"  /* reserve 160 stack bytes */       \
+         "lg 2, 8(1)\n\t"  /* arg1 -> r2 */                      \
+         "lg 3,16(1)\n\t"  /* arg2 -> r3 */                      \
+         "lg 4,24(1)\n\t"  /* arg3 -> r4 */                      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1, arg2, arg3, arg4)       \
+   do { /* arg1-4 in r2-r5; result from r2 */                    \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[5];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"  /* reserve 160 stack bytes */       \
+         "lg 2, 8(1)\n\t"  /* arg1 -> r2 */                      \
+         "lg 3,16(1)\n\t"  /* arg2 -> r3 */                      \
+         "lg 4,24(1)\n\t"  /* arg3 -> r4 */                      \
+         "lg 5,32(1)\n\t"  /* arg4 -> r5 */                      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1, arg2, arg3, arg4, arg5)   \
+   do { /* arg1-5 in r2-r6; result from r2 */                    \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[6];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"  /* reserve 160 stack bytes */       \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"  /* arg5 -> r6; "6" is in trash */     \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6)                                       \
+   do { /* arg1-5 in r2-r6, arg6 on stack */                     \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[7];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-168\n\t"  /* 160 + 8 for arg6 */              \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"  /* arg6 -> 160(r15) */      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,168\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6, arg7)                                 \
+   do { /* arg1-5 in r2-r6, arg6-7 on stack */                   \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[8];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-176\n\t"  /* 160 + 16 for arg6-7 */           \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"  /* arg6 -> 160(r15) */      \
+         "mvc 168(8,15), 56(1)\n\t"  /* arg7 -> 168(r15) */      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,176\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6, arg7 ,arg8)                           \
+   do { /* arg1-5 in r2-r6, arg6-8 on stack */                   \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[9];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-184\n\t"  /* 160 + 24 for arg6-8 */           \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"  /* arg6 -> 160(r15) */      \
+         "mvc 168(8,15), 56(1)\n\t"  /* arg7 -> 168(r15) */      \
+         "mvc 176(8,15), 64(1)\n\t"  /* arg8 -> 176(r15) */      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,184\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6, arg7 ,arg8, arg9)                     \
+   do { /* arg1-5 in r2-r6, arg6-9 on stack */                   \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[10];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-192\n\t"  /* 160 + 32 for arg6-9 */           \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"  /* arg6 -> 160(r15) */      \
+         "mvc 168(8,15), 56(1)\n\t"  /* arg7 -> 168(r15) */      \
+         "mvc 176(8,15), 64(1)\n\t"  /* arg8 -> 176(r15) */      \
+         "mvc 184(8,15), 72(1)\n\t"  /* arg9 -> 184(r15) */      \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,192\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1, arg2, arg3, arg4, arg5,  \
+                     arg6, arg7 ,arg8, arg9, arg10)              \
+   do { /* arg1-5 in r2-r6, arg6-10 on stack */                  \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[11];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      _argvec[10] = (unsigned long)arg10;                        \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-200\n\t"  /* 160 + 40 for arg6-10 */          \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"  /* arg6 -> 160(r15) */      \
+         "mvc 168(8,15), 56(1)\n\t"  /* arg7 -> 168(r15) */      \
+         "mvc 176(8,15), 64(1)\n\t"  /* arg8 -> 176(r15) */      \
+         "mvc 184(8,15), 72(1)\n\t"  /* arg9 -> 184(r15) */      \
+         "mvc 192(8,15), 80(1)\n\t"  /* arg10 -> 192(r15) */     \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,200\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1, arg2, arg3, arg4, arg5,  \
+                     arg6, arg7 ,arg8, arg9, arg10, arg11)       \
+   do { /* arg1-5 in r2-r6, arg6-11 on stack */                  \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[12];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      _argvec[10] = (unsigned long)arg10;                        \
+      _argvec[11] = (unsigned long)arg11;                        \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-208\n\t"  /* 160 + 48 for arg6-11 */          \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"  /* arg6 -> 160(r15) */      \
+         "mvc 168(8,15), 56(1)\n\t"  /* arg7 -> 168(r15) */      \
+         "mvc 176(8,15), 64(1)\n\t"  /* arg8 -> 176(r15) */      \
+         "mvc 184(8,15), 72(1)\n\t"  /* arg9 -> 184(r15) */      \
+         "mvc 192(8,15), 80(1)\n\t"  /* arg10 -> 192(r15) */     \
+         "mvc 200(8,15), 88(1)\n\t"  /* arg11 -> 200(r15) */     \
+         "lg 1, 0(1)\n\t"  /* target -> r1, loaded last */       \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,208\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"  /* result: r2 -> _res */               \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1, arg2, arg3, arg4, arg5,  \
+                     arg6, arg7 ,arg8, arg9, arg10, arg11, arg12)\
+   do { /* call _orig.nraddr with 12 word args; r2 -> lval */    \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[13]; /* target + 12 args */ \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      _argvec[10] = (unsigned long)arg10;                        \
+      _argvec[11] = (unsigned long)arg11;                        \
+      _argvec[12] = (unsigned long)arg12;                        \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE /* NOTE(review): presumably leaves &_argvec[0] in r1; confirm */ \
+         "aghi 15,-216\n\t"  /* r15 -= 216 (160 + 7*8) */        \
+         "lg 2, 8(1)\n\t"    /* r2..r6 = words at 8..40(r1) */   \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t" /* copy seven 8-byte words */ \
+         "mvc 168(8,15), 56(1)\n\t" /* from 48..96(r1) to      */ \
+         "mvc 176(8,15), 64(1)\n\t" /* 160..208(r15)           */ \
+         "mvc 184(8,15), 72(1)\n\t"                              \
+         "mvc 192(8,15), 80(1)\n\t"                              \
+         "mvc 200(8,15), 88(1)\n\t"                              \
+         "mvc 208(8,15), 96(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"    /* r1 = word at 0(r1): target */    \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,216\n\t"   /* undo the r15 adjustment */       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"     /* _res = r2 */                     \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+
+#endif /* PLAT_s390x_linux */
+
+/* ------------------------- mips32-linux ----------------------- */
+ 
+#if defined(PLAT_mips32_linux)
+
+/* These regs are trashed by the hidden call; the mips32 CALL_FN_ macros below name this list in their asm clobber ("trash") section. */
+#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6",       \
+"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \
+"$25", "$31"
+
+/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned
+   long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* call _orig.nraddr with no args; $2 -> lval */          \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1]; /* [0] = call target */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "subu $29, $29, 16 \n\t"  /* 16 more bytes; presumably the o32 arg area */ \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16\n\t"   /* drop the 16 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* call _orig.nraddr with 1 word arg; $2 -> lval */       \
+      volatile OrigFn        _orig = (orig);                      \
+     volatile unsigned long _argvec[2]; /* target + 1 arg */      \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "subu $29, $29, 16 \n\t"  /* 16 more bytes; presumably the o32 arg area */ \
+         "lw $4, 4(%1) \n\t"   /* arg1*/                          \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"  /* drop the 16 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory",  __CALLER_SAVED_REGS               \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* call _orig.nraddr with 2 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3]; /* target + 2 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "subu $29, $29, 16 \n\t"  /* 16 more bytes; presumably the o32 arg area */ \
+         "lw $4, 4(%1) \n\t"       /* $4 = arg1 */                \
+         "lw $5, 8(%1) \n\t"       /* $5 = arg2 */                \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"  /* drop the 16 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* call _orig.nraddr with 3 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4]; /* target + 3 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "subu $29, $29, 16 \n\t"  /* 16 more bytes; presumably the o32 arg area */ \
+         "lw $4, 4(%1) \n\t"       /* $4 = arg1 */                \
+         "lw $5, 8(%1) \n\t"       /* $5 = arg2 */                \
+         "lw $6, 12(%1) \n\t"      /* $6 = arg3 */                \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"  /* drop the 16 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* call _orig.nraddr with 4 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5]; /* target + 4 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "subu $29, $29, 16 \n\t"  /* 16 more bytes; presumably the o32 arg area */ \
+         "lw $4, 4(%1) \n\t"       /* $4 = arg1 */                \
+         "lw $5, 8(%1) \n\t"       /* $5 = arg2 */                \
+         "lw $6, 12(%1) \n\t"      /* $6 = arg3 */                \
+         "lw $7, 16(%1) \n\t"      /* $7 = arg4 */                \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"  /* drop the 16 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* call _orig.nraddr with 5 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6]; /* target + 5 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "lw $4, 20(%1) \n\t"      /* $4 as scratch: arg5 */      \
+         "subu $29, $29, 24\n\t"   /* 24 bytes: 16 + arg5 slot (4 unused) */ \
+         "sw $4, 16($29) \n\t"     /* arg5 -> 16($29) */          \
+         "lw $4, 4(%1) \n\t"       /* $4..$7 = arg1..arg4 */      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 24 \n\t"  /* drop the 24 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* call _orig.nraddr with 6 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7]; /* target + 6 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "lw $4, 20(%1) \n\t"      /* $4 as scratch: arg5 */      \
+         "subu $29, $29, 32\n\t"   /* 32 bytes: 16 + stack-arg slots */ \
+         "sw $4, 16($29) \n\t"     /* arg5 -> 16($29) */          \
+         "lw $4, 24(%1) \n\t"      /* $4 as scratch: arg6 */      \
+         "nop\n\t"  /* NOTE(review): presumably a load-delay filler; confirm */ \
+         "sw $4, 20($29) \n\t"     /* arg6 -> 20($29) */          \
+         "lw $4, 4(%1) \n\t"       /* $4..$7 = arg1..arg4 */      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 32 \n\t"  /* drop the 32 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* call _orig.nraddr with 7 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8]; /* target + 7 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "lw $4, 20(%1) \n\t"      /* $4 as scratch: arg5 */      \
+         "subu $29, $29, 32\n\t"   /* 32 bytes: 16 + stack-arg slots */ \
+         "sw $4, 16($29) \n\t"     /* arg5 -> 16($29) */          \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"     /* arg6 -> 20($29) */          \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"     /* arg7 -> 24($29) */          \
+         "lw $4, 4(%1) \n\t"       /* $4..$7 = arg1..arg4 */      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 32 \n\t"  /* drop the 32 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* call _orig.nraddr with 8 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9]; /* target + 8 args */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "lw $4, 20(%1) \n\t"      /* $4 as scratch: arg5 */      \
+         "subu $29, $29, 40\n\t"   /* 40 bytes: 16 + stack-arg slots */ \
+         "sw $4, 16($29) \n\t"     /* arg5 -> 16($29) */          \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"     /* arg6 -> 20($29) */          \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"     /* arg7 -> 24($29) */          \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"     /* arg8 -> 28($29) */          \
+         "lw $4, 4(%1) \n\t"       /* $4..$7 = arg1..arg4 */      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 40 \n\t"  /* drop the 40 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* call _orig.nraddr with 9 word args; $2 -> lval */      \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10]; /* target + 9 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "lw $4, 20(%1) \n\t"      /* $4 as scratch: arg5 */      \
+         "subu $29, $29, 40\n\t"   /* 40 bytes: 16 + stack-arg slots */ \
+         "sw $4, 16($29) \n\t"     /* arg5 -> 16($29) */          \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"     /* arg6 -> 20($29) */          \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"     /* arg7 -> 24($29) */          \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"     /* arg8 -> 28($29) */          \
+         "lw $4, 36(%1) \n\t"                                     \
+         "sw $4, 32($29) \n\t"     /* arg9 -> 32($29) */          \
+         "lw $4, 4(%1) \n\t"       /* $4..$7 = arg1..arg4 */      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 40 \n\t"  /* drop the 40 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* call _orig.nraddr with 10 word args; $2 -> lval */     \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11]; /* target + 10 args */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"   /* 8 bytes for two saves */    \
+         "sw $28, 0($29) \n\t"     /* save $28 */                 \
+         "sw $31, 4($29) \n\t"     /* save $31 */                 \
+         "lw $4, 20(%1) \n\t"      /* $4 as scratch: arg5 */      \
+         "subu $29, $29, 48\n\t"   /* 48 bytes: 16 + stack-arg slots */ \
+         "sw $4, 16($29) \n\t"     /* arg5 -> 16($29) */          \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"     /* arg6 -> 20($29) */          \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"     /* arg7 -> 24($29) */          \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"     /* arg8 -> 28($29) */          \
+         "lw $4, 36(%1) \n\t"                                     \
+         "sw $4, 32($29) \n\t"     /* arg9 -> 32($29) */          \
+         "lw $4, 40(%1) \n\t"                                     \
+         "sw $4, 36($29) \n\t"     /* arg10 -> 36($29) */         \
+         "lw $4, 4(%1) \n\t"       /* $4..$7 = arg1..arg4 */      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 48 \n\t"  /* drop the 48 bytes */        \
+         "lw $28, 0($29) \n\t"     /* restore $28 */              \
+         "lw $31, 4($29) \n\t"     /* restore $31 */              \
+         "addu $29, $29, 8 \n\t"   /* drop the save slots */      \
+         "move %0, $2\n"           /* _res = $2 */                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* tied to %0's register */ \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do { /* mips32: call orig with 11 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12]; /* target + 11 args */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t" /* room for $28, $31 */          \
+         "sw $28, 0($29) \n\t"   /* save $28 */                   \
+         "sw $31, 4($29) \n\t"   /* save $31 */                   \
+         "lw $4, 20(%1) \n\t"    /* arg5 via scratch $4 */        \
+         "subu $29, $29, 48\n\t" /* stack space for args */       \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"    /* arg6 */                       \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"    /* arg7 */                       \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"    /* arg8 */                       \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 36(%1) \n\t"    /* arg9 */                       \
+         "sw $4, 32($29) \n\t"                                    \
+         "lw $4, 40(%1) \n\t"    /* arg10 */                      \
+         "sw $4, 36($29) \n\t"                                    \
+         "lw $4, 44(%1) \n\t"    /* arg11 */                      \
+         "sw $4, 40($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"     /* arg1 */                       \
+         "lw $5, 8(%1) \n\t"     /* arg2 */                       \
+         "lw $6, 12(%1) \n\t"    /* arg3 */                       \
+         "lw $7, 16(%1) \n\t"    /* arg4 */                       \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 48 \n\t" /* release arg space */         \
+         "lw $28, 0($29) \n\t"   /* restore $28 */                \
+         "lw $31, 4($29) \n\t"   /* restore $31 */                \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"         /* result from $2 */             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do { /* mips32: call orig with 12 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13]; /* target + 12 args */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t" /* room for $28, $31 */          \
+         "sw $28, 0($29) \n\t"   /* save $28 */                   \
+         "sw $31, 4($29) \n\t"   /* save $31 */                   \
+         "lw $4, 20(%1) \n\t"    /* arg5 via scratch $4 */        \
+         "subu $29, $29, 56\n\t" /* stack space for args */       \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"    /* arg6 */                       \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"    /* arg7 */                       \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"    /* arg8 */                       \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 36(%1) \n\t"    /* arg9 */                       \
+         "sw $4, 32($29) \n\t"                                    \
+         "lw $4, 40(%1) \n\t"    /* arg10 */                      \
+         "sw $4, 36($29) \n\t"                                    \
+         "lw $4, 44(%1) \n\t"    /* arg11 */                      \
+         "sw $4, 40($29) \n\t"                                    \
+         "lw $4, 48(%1) \n\t"    /* arg12 */                      \
+         "sw $4, 44($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"     /* arg1 */                       \
+         "lw $5, 8(%1) \n\t"     /* arg2 */                       \
+         "lw $6, 12(%1) \n\t"    /* arg3 */                       \
+         "lw $7, 16(%1) \n\t"    /* arg4 */                       \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 56 \n\t" /* release arg space */         \
+         "lw $28, 0($29) \n\t"   /* restore $28 */                \
+         "lw $31, 4($29) \n\t"   /* restore $31 */                \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"         /* result from $2 */             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_mips32_linux */
+
+/* ------------------------- nanomips-linux -------------------- */
+
+#if defined(PLAT_nanomips_linux)
+
+/* Regs trashed by the hidden call; used as the asm clobber list below. */
+#define __CALLER_SAVED_REGS "$t4", "$t5", "$a0", "$a1", "$a2",     \
+"$a3", "$a4", "$a5", "$a6", "$a7", "$t0", "$t1", "$t2", "$t3",     \
+"$t8","$t9", "$at"
+
+/* These CALL_FN_ macros assume that on nanomips-linux,
+   sizeof(unsigned long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* nanomips: call orig with no args -> lval */            \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];  /* target only */       \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* nanomips: call orig with 1 word arg -> lval */         \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];  /* target + 1 arg */    \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* nanomips: call orig with 2 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];  /* target + 2 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* nanomips: call orig with 3 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];  /* target + 3 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         "lw $a2,12(%1)\n\t"     /* arg3 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* nanomips: call orig with 4 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];  /* target + 4 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         "lw $a2,12(%1)\n\t"     /* arg3 */                       \
+         "lw $a3,16(%1)\n\t"     /* arg4 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* nanomips: call orig with 5 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];  /* target + 5 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         "lw $a2,12(%1)\n\t"     /* arg3 */                       \
+         "lw $a3,16(%1)\n\t"     /* arg4 */                       \
+         "lw $a4,20(%1)\n\t"     /* arg5 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* nanomips: call orig with 6 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];  /* target + 6 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         "lw $a2,12(%1)\n\t"     /* arg3 */                       \
+         "lw $a3,16(%1)\n\t"     /* arg4 */                       \
+         "lw $a4,20(%1)\n\t"     /* arg5 */                       \
+         "lw $a5,24(%1)\n\t"     /* arg6 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* nanomips: call orig with 7 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];  /* target + 7 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         "lw $a2,12(%1)\n\t"     /* arg3 */                       \
+         "lw $a3,16(%1)\n\t"     /* arg4 */                       \
+         "lw $a4,20(%1)\n\t"     /* arg5 */                       \
+         "lw $a5,24(%1)\n\t"     /* arg6 */                       \
+         "lw $a6,28(%1)\n\t"     /* arg7 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* nanomips: call orig with 8 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];  /* target + 8 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"     /* target -> $t9 */              \
+         "lw $a0, 4(%1)\n\t"     /* arg1 */                       \
+         "lw $a1, 8(%1)\n\t"     /* arg2 */                       \
+         "lw $a2,12(%1)\n\t"     /* arg3 */                       \
+         "lw $a3,16(%1)\n\t"     /* arg4 */                       \
+         "lw $a4,20(%1)\n\t"     /* arg5 */                       \
+         "lw $a5,24(%1)\n\t"     /* arg6 */                       \
+         "lw $a6,28(%1)\n\t"     /* arg7 */                       \
+         "lw $a7,32(%1)\n\t"     /* arg8 */                       \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"        /* result from $a0 */            \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* nanomips: call orig with 9 word args -> lval */        \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10]; /* target + 9 args */   \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t" /* room for stack args */    \
+         "lw $t9,36(%1)        \n\t" /* arg9 via $t9 */           \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t" /* target -> $t9 */          \
+         "lw $a0, 4(%1)        \n\t" /* arg1 */                   \
+         "lw $a1, 8(%1)        \n\t" /* arg2 */                   \
+         "lw $a2,12(%1)        \n\t" /* arg3 */                   \
+         "lw $a3,16(%1)        \n\t" /* arg4 */                   \
+         "lw $a4,20(%1)        \n\t" /* arg5 */                   \
+         "lw $a5,24(%1)        \n\t" /* arg6 */                   \
+         "lw $a6,28(%1)        \n\t" /* arg7 */                   \
+         "lw $a7,32(%1)        \n\t" /* arg8 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t" /* result from $a0 */        \
+         "addiu $sp, $sp, 16   \n\t" /* release stack space */    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* nanomips: call orig with 10 word args -> lval */       \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11]; /* target + 10 args */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t" /* room for stack args */    \
+         "lw $t9,36(%1)        \n\t" /* arg9 via $t9 */           \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9,40(%1)        \n\t" /* arg10 via $t9 */          \
+         "sw $t9, 4($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t" /* target -> $t9 */          \
+         "lw $a0, 4(%1)        \n\t" /* arg1 */                   \
+         "lw $a1, 8(%1)        \n\t" /* arg2 */                   \
+         "lw $a2,12(%1)        \n\t" /* arg3 */                   \
+         "lw $a3,16(%1)        \n\t" /* arg4 */                   \
+         "lw $a4,20(%1)        \n\t" /* arg5 */                   \
+         "lw $a5,24(%1)        \n\t" /* arg6 */                   \
+         "lw $a6,28(%1)        \n\t" /* arg7 */                   \
+         "lw $a7,32(%1)        \n\t" /* arg8 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t" /* result from $a0 */        \
+         "addiu $sp, $sp, 16   \n\t" /* release stack space */    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do { /* nanomips: call orig with 11 word args -> lval */       \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12]; /* target + 11 args */  \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t" /* room for stack args */    \
+         "lw $t9,36(%1)        \n\t" /* arg9 via $t9 */           \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9,40(%1)        \n\t" /* arg10 via $t9 */          \
+         "sw $t9, 4($sp)       \n\t"                              \
+         "lw $t9,44(%1)        \n\t" /* arg11 via $t9 */          \
+         "sw $t9, 8($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t" /* target -> $t9 */          \
+         "lw $a0, 4(%1)        \n\t" /* arg1 */                   \
+         "lw $a1, 8(%1)        \n\t" /* arg2 */                   \
+         "lw $a2,12(%1)        \n\t" /* arg3 */                   \
+         "lw $a3,16(%1)        \n\t" /* arg4 */                   \
+         "lw $a4,20(%1)        \n\t" /* arg5 */                   \
+         "lw $a5,24(%1)        \n\t" /* arg6 */                   \
+         "lw $a6,28(%1)        \n\t" /* arg7 */                   \
+         "lw $a7,32(%1)        \n\t" /* arg8 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t" /* result from $a0 */        \
+         "addiu $sp, $sp, 16   \n\t" /* release stack space */    \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do { /* Hidden call to orig with 12 args; result -> lval. */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13]; /* target, args */      \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t" /* room for args 9..12 */    \
+         "lw $t9,36(%1)        \n\t" /* arg9  -> stack */         \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9,40(%1)        \n\t" /* arg10 -> stack */         \
+         "sw $t9, 4($sp)       \n\t"                              \
+         "lw $t9,44(%1)        \n\t" /* arg11 -> stack */         \
+         "sw $t9, 8($sp)       \n\t"                              \
+         "lw $t9,48(%1)        \n\t" /* arg12 -> stack */         \
+         "sw $t9,12($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t" /* target -> t9 */           \
+         "lw $a0, 4(%1)        \n\t" /* arg1..arg8 -> a0..a7 */   \
+         "lw $a1, 8(%1)        \n\t"                              \
+         "lw $a2,12(%1)        \n\t"                              \
+         "lw $a3,16(%1)        \n\t"                              \
+         "lw $a4,20(%1)        \n\t"                              \
+         "lw $a5,24(%1)        \n\t"                              \
+         "lw $a6,28(%1)        \n\t"                              \
+         "lw $a7,32(%1)        \n\t"                              \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t" /* result <- a0 */           \
+         "addiu $sp, $sp, 16   \n\t" /* pop stack args */         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_nanomips_linux */
+
+/* ------------------------- mips64-linux ------------------------- */
+
+#if defined(PLAT_mips64_linux)
+
+/* These regs are trashed by the hidden call.  This list is the
+   register clobber ("trash") list of each asm statement in the
+   mips64 CALL_FN_ macros; it includes $2 (read back as the result)
+   and $25 (loaded with the call target). */
+#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6",       \
+"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \
+"$25", "$31"
+
+/* These CALL_FN_ macros assume that on mips64-linux,
+   sizeof(long long) == 8. */
+
+/* Convert x to long, then widen it to long long so it fills one
+   64-bit _argvec slot (a negative long stays negative, i.e. it is
+   sign-extended).  The argument is parenthesized so that the cast
+   applies to the whole expression, e.g. MIPS64_LONG2REG_CAST(p + 1)
+   converts (p + 1) rather than computing (long)p + 1. */
+#define MIPS64_LONG2REG_CAST(x) ((long long)(long)(x))
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do { /* Hidden call to orig with no args; result -> lval. */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[1]; /* target only */   \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      __asm__ volatile(                                           \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0]) /* same reg as %0 */       \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do { /* Hidden call to orig with 1 arg; result -> lval. */     \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[2]; /* target, args */  \
+      volatile unsigned long long  _res;                          \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"   /* arg1*/                           \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do { /* Hidden call to orig with 2 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[3]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      /* Cast the target like the argument slots below. */        \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg2 -> $4..$5 */            \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do { /* Hidden call to orig with 3 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[4]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      /* Cast the target like the argument slots below. */        \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg3 -> $4..$6 */            \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do { /* Hidden call to orig with 4 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[5]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg4 -> $4..$7 */            \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do { /* Hidden call to orig with 5 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[6]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg5 -> $4..$8 */            \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do { /* Hidden call to orig with 6 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[7]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg6 -> $4..$9 */            \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do { /* Hidden call to orig with 7 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[8]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg7 -> $4..$10 */           \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do { /* Hidden call to orig with 8 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[9]; /* target, args */  \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t" /* arg1..arg8 -> $4..$11 */           \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do { /* Hidden call to orig with 9 args; result -> lval. */    \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[10]; /* target, args */ \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 8\n\t" /* room for arg9 */              \
+         "ld $4, 72(%1)\n\t" /* arg9, via scratch $4 */           \
+         "sd $4, 0($29)\n\t" /* TODO: confirm sp stays aligned */ \
+         "ld $4, 8(%1)\n\t" /* arg1..arg8 -> $4..$11 */           \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 8\n\t" /* pop stack args */             \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do { /* Hidden call to orig with 10 args; result -> lval. */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[11]; /* target, args */ \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      _argvec[10] = MIPS64_LONG2REG_CAST(arg10);                  \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 16\n\t" /* room for arg9..arg10 */      \
+         "ld $4, 72(%1)\n\t" /* arg9, via scratch $4 */           \
+         "sd $4, 0($29)\n\t"                                      \
+         "ld $4, 80(%1)\n\t" /* arg10 */                          \
+         "sd $4, 8($29)\n\t"                                      \
+         "ld $4, 8(%1)\n\t" /* arg1..arg8 -> $4..$11 */           \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 16\n\t" /* pop stack args */            \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do { /* Hidden call to orig with 11 args; result -> lval. */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[12]; /* target, args */ \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      _argvec[10] = MIPS64_LONG2REG_CAST(arg10);                  \
+      _argvec[11] = MIPS64_LONG2REG_CAST(arg11);                  \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 24\n\t" /* room for arg9..arg11 */      \
+         "ld $4, 72(%1)\n\t" /* arg9, via scratch $4 */           \
+         "sd $4, 0($29)\n\t" /* TODO: confirm sp stays aligned */ \
+         "ld $4, 80(%1)\n\t" /* arg10 */                          \
+         "sd $4, 8($29)\n\t"                                      \
+         "ld $4, 88(%1)\n\t" /* arg11 */                          \
+         "sd $4, 16($29)\n\t"                                     \
+         "ld $4, 8(%1)\n\t" /* arg1..arg8 -> $4..$11 */           \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 24\n\t" /* pop stack args */            \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do { /* Hidden call to orig with 12 args; result -> lval. */   \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[13]; /* target, args */ \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      _argvec[10] = MIPS64_LONG2REG_CAST(arg10);                  \
+      _argvec[11] = MIPS64_LONG2REG_CAST(arg11);                  \
+      _argvec[12] = MIPS64_LONG2REG_CAST(arg12);                  \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 32\n\t" /* room for arg9..arg12 */      \
+         "ld $4, 72(%1)\n\t" /* arg9, via scratch $4 */           \
+         "sd $4, 0($29)\n\t"                                      \
+         "ld $4, 80(%1)\n\t" /* arg10 */                          \
+         "sd $4, 8($29)\n\t"                                      \
+         "ld $4, 88(%1)\n\t" /* arg11 */                          \
+         "sd $4, 16($29)\n\t"                                     \
+         "ld $4, 96(%1)\n\t" /* arg12 */                          \
+         "sd $4, 24($29)\n\t"                                     \
+         "ld $4, 8(%1)\n\t" /* arg1..arg8 -> $4..$11 */           \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 32\n\t" /* pop stack args */            \
+         "move %0, $2\n" /* result <- $2 */                       \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#endif /* PLAT_mips64_linux */
+
+/* ------------------------------------------------------------------ */
+/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS.               */
+/*                                                                    */
+/* ------------------------------------------------------------------ */
+
+/* Some request codes.  There are many more of these, but most are not
+   exposed to end-user view.  These are the public ones, all of the
+   form 0x1000 + small_number.
+
+   Core ones are in the range 0x00000000--0x0000ffff.  The non-public
+   ones start at 0x2000.
+*/
+
+/* These macros are used by tools -- they must be public, but don't
+   embed them into other programs. */
+/* Build a tool's base request code from two byte values: the low 8
+   bits of a go in bits 31..24 and the low 8 bits of b in bits 23..16.
+   Each byte is converted to unsigned int before it is shifted, so a
+   value >= 0x80 in a is never shifted into the sign bit of an int.
+   The numeric results are the same as before. */
+#define VG_USERREQ_TOOL_BASE(a,b) \
+   ((unsigned int)((a)&0xff) << 24 | (unsigned int)((b)&0xff) << 16)
+/* Non-zero if the upper 16 bits of request code v equal
+   VG_USERREQ_TOOL_BASE(a,b). */
+#define VG_IS_TOOL_USERREQ(a, b, v) \
+   (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000))
+
+/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! 
+   This enum comprises an ABI exported by Valgrind to programs
+   which use client requests.  DO NOT CHANGE THE NUMERIC VALUES OF THESE
+   ENTRIES, NOR DELETE ANY -- add new ones at the end of the most
+   relevant group. */
+typedef
+   enum { VG_USERREQ__RUNNING_ON_VALGRIND  = 0x1001,
+          VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002,
+
+          /* These allow any function to be called from the simulated
+             CPU but run on the real CPU.  Nb: the first arg passed to
+             the function is always the ThreadId of the running
+             thread!  So CLIENT_CALL0 actually requires a 1 arg
+             function, etc. */
+          VG_USERREQ__CLIENT_CALL0 = 0x1101,
+          VG_USERREQ__CLIENT_CALL1 = 0x1102,
+          VG_USERREQ__CLIENT_CALL2 = 0x1103,
+          VG_USERREQ__CLIENT_CALL3 = 0x1104,
+
+          /* Can be useful in regression testing suites -- eg. can
+             send Valgrind's output to /dev/null and still count
+             errors. */
+          VG_USERREQ__COUNT_ERRORS = 0x1201,
+
+          /* Allows the client program and/or gdbserver to execute a monitor
+             command. */
+          VG_USERREQ__GDB_MONITOR_COMMAND = 0x1202,
+
+          /* Allows the client program to change a dynamic command line
+             option.  */
+          VG_USERREQ__CLO_CHANGE = 0x1203,
+
+          /* These are useful and can be interpreted by any tool that
+             tracks malloc() et al, by using vg_replace_malloc.c. */
+          VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301,
+          VG_USERREQ__RESIZEINPLACE_BLOCK = 0x130b,
+          VG_USERREQ__FREELIKE_BLOCK   = 0x1302,
+          /* Memory pool support. */
+          VG_USERREQ__CREATE_MEMPOOL   = 0x1303,
+          VG_USERREQ__DESTROY_MEMPOOL  = 0x1304,
+          VG_USERREQ__MEMPOOL_ALLOC    = 0x1305,
+          VG_USERREQ__MEMPOOL_FREE     = 0x1306,
+          VG_USERREQ__MEMPOOL_TRIM     = 0x1307,
+          VG_USERREQ__MOVE_MEMPOOL     = 0x1308,
+          VG_USERREQ__MEMPOOL_CHANGE   = 0x1309,
+          VG_USERREQ__MEMPOOL_EXISTS   = 0x130a,
+
+          /* Allow printfs to valgrind log. */
+          /* The first two pass the va_list argument by value, which
+             assumes it is the same size as or smaller than a UWord,
+             which generally isn't the case.  Hence they are deprecated.
+             The second two pass the vargs by reference and so are
+             immune to this problem. */
+          /* both :: char* fmt, va_list vargs (DEPRECATED) */
+          VG_USERREQ__PRINTF           = 0x1401,
+          VG_USERREQ__PRINTF_BACKTRACE = 0x1402,
+          /* both :: char* fmt, va_list* vargs */
+          VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403,
+          VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404,
+
+          /* Stack support. */
+          VG_USERREQ__STACK_REGISTER   = 0x1501,
+          VG_USERREQ__STACK_DEREGISTER = 0x1502,
+          VG_USERREQ__STACK_CHANGE     = 0x1503,
+
+          /* Wine support */
+          VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601,
+
+          /* Querying of debug info. */
+          VG_USERREQ__MAP_IP_TO_SRCLOC = 0x1701,
+
+          /* Disable/enable error reporting level.  Takes a single
+             Word arg which is the delta to this thread's error
+             disablement indicator.  Hence 1 disables or further
+             disables errors, and -1 moves back towards enablement.
+             Other values are not allowed. */
+          VG_USERREQ__CHANGE_ERR_DISABLEMENT = 0x1801,
+
+          /* Some requests used internally by Valgrind, such as
+             for self-test or self-hosting. */
+          /* Initialise IR injection */
+          VG_USERREQ__VEX_INIT_FOR_IRI = 0x1901,
+          /* Used by Inner Valgrind to inform Outer Valgrind where to
+             find the list of inner guest threads */
+          VG_USERREQ__INNER_THREADS    = 0x1902
+   } Vg_ClientRequest;
+
+#if !defined(__GNUC__)
+#  define __extension__ /* */
+#endif
+
+
+/* Returns the number of Valgrinds this code is running under.  That
+   is, 0 if running natively, 1 if running under Valgrind, 2 if
+   running under Valgrind which is running under another Valgrind,
+   etc. */
+#define RUNNING_ON_VALGRIND                                           \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* if not */,         \
+                                    VG_USERREQ__RUNNING_ON_VALGRIND,  \
+                                    0, 0, 0, 0, 0)                    \
+
+
+/* Discard translation of code in the range [_qzz_addr .. _qzz_addr +
+   _qzz_len - 1].  Useful if you are debugging a JITter or some such,
+   since it provides a way to make sure valgrind will retranslate the
+   invalidated area.  Returns no value. */
+#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len)              \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DISCARD_TRANSLATIONS,  \
+                                    _qzz_addr, _qzz_len, 0, 0, 0)
+
+#define VALGRIND_INNER_THREADS(_qzz_addr)                               \
+   VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__INNER_THREADS,           \
+                                   _qzz_addr, 0, 0, 0, 0)
+
+
+/* These requests are for getting Valgrind itself to print something.
+   Possibly with a backtrace.  This is a really ugly hack.  The return value
+   is the number of characters printed, excluding the "**** " part at the
+   start and the backtrace (if present). */
+
+#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER)
+/* Modern GCC will optimize the static routine out if unused, and the unused
+   attribute will shut down warnings about it.  format() type-checks args.  */
+static int VALGRIND_PRINTF(const char *format, ...)
+   __attribute__((format(__printf__, 1, 2), __unused__));
+#endif
+static int
+#if defined(_MSC_VER)
+__inline
+#endif
+VALGRIND_PRINTF(const char *format, ...)   /* printf into Valgrind's output */
+{
+#if defined(NVALGRIND)
+   (void)format;             /* keep the parameter "used" in this stub */
+   return 0;                 /* NVALGRIND build: no request is issued */
+#else /* NVALGRIND */
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   uintptr_t _qzz_res;       /* pointer-sized; the casts below carry pointers */
+#else
+   unsigned long _qzz_res;
+#endif
+   va_list vargs;
+   va_start(vargs, format);
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,   /* default return */
+                              VG_USERREQ__PRINTF_VALIST_BY_REF,
+                              (uintptr_t)format,
+                              (uintptr_t)&vargs,   /* va_list by reference */
+                              0, 0, 0);
+#else
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,   /* default return */
+                              VG_USERREQ__PRINTF_VALIST_BY_REF,
+                              (unsigned long)format,
+                              (unsigned long)&vargs, 
+                              0, 0, 0);
+#endif
+   va_end(vargs);
+   return (int)_qzz_res;     /* request result narrowed to int (chars printed) */
+#endif /* NVALGRIND */
+}
+
+#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER)
+static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...)
+   __attribute__((format(__printf__, 1, 2), __unused__));
+#endif
+static int
+#if defined(_MSC_VER)
+__inline
+#endif
+VALGRIND_PRINTF_BACKTRACE(const char *format, ...)   /* printf + backtrace */
+{
+#if defined(NVALGRIND)
+   (void)format;             /* keep the parameter "used" in this stub */
+   return 0;                 /* NVALGRIND build: no request is issued */
+#else /* NVALGRIND */
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   uintptr_t _qzz_res;       /* pointer-sized; the casts below carry pointers */
+#else
+   unsigned long _qzz_res;
+#endif
+   va_list vargs;
+   va_start(vargs, format);
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,   /* default return */
+                              VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF,
+                              (uintptr_t)format,
+                              (uintptr_t)&vargs,   /* va_list by reference */
+                              0, 0, 0);
+#else
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,   /* default return */
+                              VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF,
+                              (unsigned long)format,
+                              (unsigned long)&vargs, 
+                              0, 0, 0);
+#endif
+   va_end(vargs);
+   return (int)_qzz_res;     /* request result narrowed to int (chars printed) */
+#endif /* NVALGRIND */
+}
+
+
+/* These requests allow control to move from the simulated CPU to the
+   real CPU, calling an arbitrary function.
+   
+   Note that the current ThreadId is inserted as the first argument.
+   So this call:
+
+     VALGRIND_NON_SIMD_CALL2(f, arg1, arg2)
+
+   requires f to have this signature:
+
+     Word f(Word tid, Word arg1, Word arg2)
+
+   where "Word" is a word-sized type.
+
+   Note that these client requests are not entirely reliable.  For example,
+   if you call a function with them that subsequently calls printf(),
+   there's a high chance Valgrind will crash.  Generally, your prospects of
+   these working are made higher if the called function does not refer to
+   any global variables, and does not refer to any libc or other functions
+   (printf et al).  Any kind of entanglement with libc or dynamic linking is
+   likely to have a bad outcome, for tricky reasons which we've grappled
+   with a lot in the past.
+*/
+#define VALGRIND_NON_SIMD_CALL0(_qyy_fn)                          \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,       \
+                                    VG_USERREQ__CLIENT_CALL0,     \
+                                    _qyy_fn,                      \
+                                    0, 0, 0, 0)
+
+#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1)                    \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,            \
+                                    VG_USERREQ__CLIENT_CALL1,          \
+                                    _qyy_fn,                           \
+                                    _qyy_arg1, 0, 0, 0)
+
+#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2)         \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,            \
+                                    VG_USERREQ__CLIENT_CALL2,          \
+                                    _qyy_fn,                           \
+                                    _qyy_arg1, _qyy_arg2, 0, 0)
+
+#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,             \
+                                    VG_USERREQ__CLIENT_CALL3,           \
+                                    _qyy_fn,                            \
+                                    _qyy_arg1, _qyy_arg2,               \
+                                    _qyy_arg3, 0)
+
+
+/* Counts the number of errors that have been recorded by a tool.  Nb:
+   the tool must record the errors with VG_(maybe_record_error)() or
+   VG_(unique_error)() for them to be counted. */
+#define VALGRIND_COUNT_ERRORS                                     \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(                    \
+                               0 /* default return */,            \
+                               VG_USERREQ__COUNT_ERRORS,          \
+                               0, 0, 0, 0, 0)
+
+/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing
+   when heap blocks are allocated in order to give accurate results.  This
+   happens automatically for the standard allocator functions such as
+   malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete,
+   delete[], etc.
+
+   But if your program uses a custom allocator, this doesn't automatically
+   happen, and Valgrind will not do as well.  For example, if you allocate
+   superblocks with mmap() and then allocate chunks of the superblocks, all
+   Valgrind's observations will be at the mmap() level and it won't know that
+   the chunks should be considered separate entities.  In Memcheck's case,
+   that means you probably won't get heap block overrun detection (because
+   there won't be redzones marked as unaddressable) and you definitely won't
+   get any leak detection.
+
+   The following client requests allow a custom allocator to be annotated so
+   that it can be handled accurately by Valgrind.
+
+   VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated
+   by a malloc()-like function.  For Memcheck (an illustrative case), this
+   does two things:
+
+   - It records that the block has been allocated.  This means any addresses
+     within the block mentioned in error messages will be
+     identified as belonging to the block.  It also means that if the block
+     isn't freed it will be detected by the leak checker.
+
+   - It marks the block as being addressable and undefined (if 'is_zeroed' is
+     not set), or addressable and defined (if 'is_zeroed' is set).  This
+     controls how accesses to the block by the program are handled.
+   
+   'addr' is the start of the usable block (ie. after any
+   redzone), 'sizeB' is its size.  'rzB' is the redzone size if the allocator
+   can apply redzones -- these are blocks of padding at the start and end of
+   each block.  Adding redzones is recommended as it makes it much more likely
+   Valgrind will spot block overruns.  `is_zeroed' indicates if the memory is
+   zeroed (or filled with another predictable value), as is the case for
+   calloc().
+   
+   VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a
+   heap block -- that will be used by the client program -- is allocated.
+   It's best to put it at the outermost level of the allocator if possible;
+   for example, if you have a function my_alloc() which calls
+   internal_alloc(), and the client request is put inside internal_alloc(),
+   stack traces relating to the heap block will contain entries for both
+   my_alloc() and internal_alloc(), which is probably not what you want.
+
+   For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out
+   custom blocks from within a heap block, B, that has been allocated with
+   malloc/calloc/new/etc, then block B will be *ignored* during leak-checking
+   -- the custom blocks will take precedence.
+
+   VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK.  For
+   Memcheck, it does two things:
+
+   - It records that the block has been deallocated.  This assumes that the
+     block was annotated as having been allocated via
+     VALGRIND_MALLOCLIKE_BLOCK.  Otherwise, an error will be issued.
+
+   - It marks the block as being unaddressable.
+
+   VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a
+   heap block is deallocated.
+
+   VALGRIND_RESIZEINPLACE_BLOCK informs a tool about reallocation. For
+   Memcheck, it does four things:
+
+   - It records that the size of a block has been changed.  This assumes that
+     the block was annotated as having been allocated via
+     VALGRIND_MALLOCLIKE_BLOCK.  Otherwise, an error will be issued.
+
+   - If the block shrunk, it marks the freed memory as being unaddressable.
+
+   - If the block grew, it marks the new area as undefined and defines a red
+     zone past the end of the new block.
+
+   - The V-bits of the overlap between the old and the new block are preserved.
+
+   VALGRIND_RESIZEINPLACE_BLOCK should be put after allocation of the new block
+   and before deallocation of the old block.
+
+   In many cases, these three client requests will not be enough to get your
+   allocator working well with Memcheck.  More specifically, if your allocator
+   writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call
+   will be necessary to mark the memory as addressable just before the zeroing
+   occurs, otherwise you'll get a lot of invalid write errors.  For example,
+   you'll need to do this if your allocator recycles freed blocks, but it
+   zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK).
+   Alternatively, if your allocator reuses freed blocks for allocator-internal
+   data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary.
+
+   Really, what's happening is a blurring of the lines between the client
+   program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the
+   memory should be considered unaddressable to the client program, but the
+   allocator knows more than the rest of the client program and so may be able
+   to safely access it.  Extra client requests are necessary for Valgrind to
+   understand the distinction between the allocator and the rest of the
+   program.
+
+   Ignored if addr == 0.
+*/
+#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed)          \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MALLOCLIKE_BLOCK,       \
+                                    addr, sizeB, rzB, is_zeroed, 0)
+
+/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details.
+   Ignored if addr == 0.
+*/
+#define VALGRIND_RESIZEINPLACE_BLOCK(addr, oldSizeB, newSizeB, rzB)     \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__RESIZEINPLACE_BLOCK,    \
+                                    addr, oldSizeB, newSizeB, rzB, 0)
+
+/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details.
+   Ignored if addr == 0.
+*/
+#define VALGRIND_FREELIKE_BLOCK(addr, rzB)                              \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__FREELIKE_BLOCK,         \
+                                    addr, rzB, 0, 0, 0)
+
+/* Create a memory pool. */
+#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed)             \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL,   \
+                                    pool, rzB, is_zeroed, 0, 0)
+
+/* Create a memory pool with some flags specifying extended behaviour.
+   When flags is zero, the behaviour is identical to VALGRIND_CREATE_MEMPOOL.
+   
+   The flag VALGRIND_MEMPOOL_METAPOOL specifies that the pieces of memory 
+   associated with the pool using VALGRIND_MEMPOOL_ALLOC  will be used
+   by the application as superblocks to dole out MALLOC_LIKE blocks using
+   VALGRIND_MALLOCLIKE_BLOCK. In other words, a meta pool is a "2 levels"
+   pool : first level is the blocks described by VALGRIND_MEMPOOL_ALLOC.
+   The second level blocks are described using VALGRIND_MALLOCLIKE_BLOCK.
+   Note that the association between the pool and the second level blocks
+   is implicit : second level blocks will be located inside first level
+   blocks. It is necessary to use the VALGRIND_MEMPOOL_METAPOOL flag
+   for such 2 levels pools, as otherwise valgrind will detect overlapping
+   memory blocks, and will abort execution (e.g. during leak search).
+
+   Such a meta pool can also be marked as an 'auto free' pool using the flag
+   VALGRIND_MEMPOOL_AUTO_FREE, which must be OR-ed together with the
+   VALGRIND_MEMPOOL_METAPOOL. For an 'auto free' pool, VALGRIND_MEMPOOL_FREE
+   will automatically free the second level blocks that are contained
+   inside the first level block freed with VALGRIND_MEMPOOL_FREE.
+   In other words, calling VALGRIND_MEMPOOL_FREE will cause implicit calls
+   to VALGRIND_FREELIKE_BLOCK for all the second level blocks included
+   in the first level block.
+   Note: it is an error to use the VALGRIND_MEMPOOL_AUTO_FREE flag
+   without the VALGRIND_MEMPOOL_METAPOOL flag.
+*/
+#define VALGRIND_MEMPOOL_AUTO_FREE  1
+#define VALGRIND_MEMPOOL_METAPOOL   2
+#define VALGRIND_CREATE_MEMPOOL_EXT(pool, rzB, is_zeroed, flags)        \
+   VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL,          \
+                                   pool, rzB, is_zeroed, flags, 0)
+
+/* Destroy a memory pool. */
+#define VALGRIND_DESTROY_MEMPOOL(pool)                            \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DESTROY_MEMPOOL,  \
+                                    pool, 0, 0, 0, 0)
+
+/* Associate a piece of memory with a memory pool. */
+#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size)                  \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_ALLOC,    \
+                                    pool, addr, size, 0, 0)
+
+/* Disassociate a piece of memory from a memory pool. */
+#define VALGRIND_MEMPOOL_FREE(pool, addr)                         \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_FREE,     \
+                                    pool, addr, 0, 0, 0)
+
+/* Disassociate any pieces outside a particular range. */
+#define VALGRIND_MEMPOOL_TRIM(pool, addr, size)                   \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_TRIM,     \
+                                    pool, addr, size, 0, 0)
+
+/* Move a memory pool: the pool anchored at poolA is now anchored at poolB. */
+#define VALGRIND_MOVE_MEMPOOL(poolA, poolB)                       \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MOVE_MEMPOOL,     \
+                                    poolA, poolB, 0, 0, 0)
+
+/* Resize and/or move a piece associated with a memory pool. */
+#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size)         \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_CHANGE,   \
+                                    pool, addrA, addrB, size, 0)
+
+/* Return 1 if a mempool exists, else 0. */
+#define VALGRIND_MEMPOOL_EXISTS(pool)                             \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                  \
+                               VG_USERREQ__MEMPOOL_EXISTS,        \
+                               pool, 0, 0, 0, 0)
+
+/* Mark a piece of memory as being a stack. Returns a stack id.
+   start is the lowest addressable stack byte, end is the highest
+   addressable stack byte. */
+#define VALGRIND_STACK_REGISTER(start, end)                       \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                  \
+                               VG_USERREQ__STACK_REGISTER,        \
+                               start, end, 0, 0, 0)
+
+/* Unmark the piece of memory associated with a stack id as being a
+   stack. */
+#define VALGRIND_STACK_DEREGISTER(id)                             \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_DEREGISTER, \
+                                    id, 0, 0, 0, 0)
+
+/* Change the start and end address of the stack id.
+   start is the new lowest addressable stack byte, end is the new highest
+   addressable stack byte. */
+#define VALGRIND_STACK_CHANGE(id, start, end)                     \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_CHANGE,     \
+                                    id, start, end, 0, 0)
+
+/* Load PDB debug info for Wine PE image_map. */
+#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta)     \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__LOAD_PDB_DEBUGINFO, \
+                                    fd, ptr, total_size, delta, 0)
+
+/* Map a code address to a source file name and line number.  buf64
+   must point to a 64-byte buffer in the caller's address space.  The
+   result will be dumped in there and is guaranteed to be zero
+   terminated.  If no info is found, the first byte is set to zero. */
+#define VALGRIND_MAP_IP_TO_SRCLOC(addr, buf64)                    \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                  \
+                               VG_USERREQ__MAP_IP_TO_SRCLOC,      \
+                               addr, buf64, 0, 0, 0)
+
+/* Disable error reporting for this thread.  Behaves in a stack like
+   way, so you can safely call this multiple times provided that
+   VALGRIND_ENABLE_ERROR_REPORTING is called the same number of times
+   to re-enable reporting.  The first call of this macro disables
+   reporting.  Subsequent calls have no effect except to increase the
+   number of VALGRIND_ENABLE_ERROR_REPORTING calls needed to re-enable
+   reporting.  Child threads do not inherit this setting from their
+   parents -- they are always created with reporting enabled. */
+#define VALGRIND_DISABLE_ERROR_REPORTING                                \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \
+                                    1, 0, 0, 0, 0)
+
+/* Re-enable error reporting, as per comments on
+   VALGRIND_DISABLE_ERROR_REPORTING. */
+#define VALGRIND_ENABLE_ERROR_REPORTING                                 \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \
+                                    -1, 0, 0, 0, 0)
+
+/* Execute a monitor command from the client program.
+   If a connection is opened with GDB, the output will be sent
+   according to the output mode set for vgdb.
+   If no connection is opened, output will go to the log output.
+   Returns 1 if command not recognised, 0 otherwise. */
+#define VALGRIND_MONITOR_COMMAND(command)                               \
+   VALGRIND_DO_CLIENT_REQUEST_EXPR(0, VG_USERREQ__GDB_MONITOR_COMMAND, \
+                                   command, 0, 0, 0, 0)
+
+
+/* Change the value of a dynamic command line option.
+   Note that unknown or not dynamically changeable options
+   will cause a warning message to be output.  */
+#define VALGRIND_CLO_CHANGE(option)                           \
+   VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CLO_CHANGE, \
+                                   option, 0, 0, 0, 0)
+
+
+#undef PLAT_x86_darwin
+#undef PLAT_amd64_darwin
+#undef PLAT_x86_win32
+#undef PLAT_amd64_win64
+#undef PLAT_x86_linux
+#undef PLAT_amd64_linux
+#undef PLAT_ppc32_linux
+#undef PLAT_ppc64be_linux
+#undef PLAT_ppc64le_linux
+#undef PLAT_arm_linux
+#undef PLAT_s390x_linux
+#undef PLAT_mips32_linux
+#undef PLAT_mips64_linux
+#undef PLAT_nanomips_linux
+#undef PLAT_x86_solaris
+#undef PLAT_amd64_solaris
+
+#endif   /* __VALGRIND_H */
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fb6f173d8013dea8f3b2d89c1a0c60980abcfcd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6489c5aaf97e5e603e01d9060d0617d670ced084
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3097daa4bcc4b55ab748090c332e2d44c64a6fa
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b7ee59d165e50c90e3303519f297a172d4b4071
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccad97f5101f05066657b97a13df4879170cade6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db99accd6ce8717f9786ae85747b29524bb2610c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa1a3fb84cf4132b8fcf3a9e58ae70b3b5ccff4b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bafdfb30df83b46eab8392acba48519f57b630a9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5208dd499845173f47f40f5dd46ccd89ddcb9103
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f19ab52cb2b6d1c485db025aaf32edef801507b
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbb6625e8a003652ebe8e27a3b3f11bb6c7b5139
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8897218ef00a35e14f5b2ab0543a2cb356c357cd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c46828e792d34c1a5c77cb51862f95a652b902f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dc4f87484687eecf222923e7e31637a71951054
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..e198aa16caa66105c0b2009ed8da1e655effe151
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py
@@ -0,0 +1,484 @@
+# mypy: allow-untyped-defs
+# Taking reference from official Python typing
+# https://github.com/python/cpython/blob/master/Lib/typing.py
+
+import collections
+import functools
+import numbers
+import sys
+
+# Please check [Note: TypeMeta and TypeAlias]
+# In case of metaclass conflict due to ABCMeta or _ProtocolMeta
+# For Python 3.9, only Protocol in typing uses metaclass
+from abc import ABCMeta
+from collections.abc import Iterator
+
+# TODO: Use TypeAlias when Python 3.6 is deprecated
+from typing import (  # type: ignore[attr-defined]
+    _eval_type,
+    _GenericAlias,
+    _tp_cache,
+    _type_check,
+    _type_repr,
+    Any,
+    ForwardRef,
+    Generic,
+    get_type_hints,
+    TypeVar,
+    Union,
+)
+
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator
+
+
+class GenericMeta(ABCMeta):  # type: ignore[no-redef]
+    pass  # adds nothing to ABCMeta; serves as the base metaclass of _DataPipeMeta
+
+
+class Integer(numbers.Integral):
+    pass  # placeholder class that TYPE2ABC substitutes for the builtin `int`
+
+
+class Boolean(numbers.Integral):
+    pass  # placeholder class that TYPE2ABC substitutes for the builtin `bool`
+
+
+# Historically (before PEP 585) a builtin Python 'type' object was not subscriptable:
+# Tuple[int, List, dict] -> valid
+# tuple[int, list, dict] -> invalid
+# Map builtin Python types (and None) to the class used when comparing types
+TYPE2ABC = {
+    bool: Boolean,
+    int: Integer,
+    float: numbers.Real,
+    complex: numbers.Complex,
+    dict: dict,
+    list: list,
+    set: set,
+    tuple: tuple,
+    None: type(None),
+}
+
+
+def issubtype(left, right, recursive=True):
+    r"""
+    Check if the left-side type is a subtype of the right-side type.
+
+    A composite type (a `Union`, or a `TypeVar` with a bound or constraints)
+    is expanded into a list of types, and every left-side type must be a
+    subtype of at least one right-side type. `recursive` is passed through as is.
+    """
+    left = TYPE2ABC.get(left, left)
+    right = TYPE2ABC.get(right, right)
+
+    if right is Any or left == right:
+        return True
+
+    if isinstance(right, _GenericAlias):
+        if getattr(right, "__origin__", None) is Generic:
+            return True
+
+    if right is type(None):
+        return False
+
+    # Right-side type, expanded into the list of types it allows
+    constraints = _decompose_type(right)
+
+    if len(constraints) == 0 or Any in constraints:
+        return True
+
+    if left is Any:
+        return False
+
+    # Left-side type, expanded into the list of types it may be
+    variants = _decompose_type(left)
+
+    # all() returns True for an empty iterable, so reject empty variants explicitly
+    if len(variants) == 0:
+        return False
+
+    return all(
+        _issubtype_with_constraints(variant, constraints, recursive)
+        for variant in variants
+    )
+
+
+def _decompose_type(t, to_list=True):
+    if isinstance(t, TypeVar):
+        if t.__bound__ is not None:
+            ts = [t.__bound__]
+        else:
+            # For T_co, __constraints__ is ()
+            ts = list(t.__constraints__)
+    elif hasattr(t, "__origin__") and t.__origin__ == Union:
+        ts = t.__args__
+    else:
+        if not to_list:
+            return None
+        ts = [t]
+    # Ignored: Generator has incompatible item type "object"; expected "Type[Any]"
+    ts = [TYPE2ABC.get(_t, _t) for _t in ts]  # type: ignore[misc]
+    return ts
+
+
+def _issubtype_with_constraints(variant, constraints, recursive=True):
+    r"""
+    Check if the variant is a subtype of at least one of the constraints.
+
+    A variant or constraint that is a `Union` or a `TypeVar` is expanded
+    with `_decompose_type` and its member types are tested in its place.
+    """
+    if variant in constraints:
+        return True
+
+    # [Note: Subtype for Union and TypeVar]
+    # Python typing is able to flatten Union[Union[...]] or Union[TypeVar].
+    # But it cannot flatten the following scenarios:
+    #   - Union[int, TypeVar[Union[...]]]
+    #   - TypeVar[TypeVar[...]]
+    # So, the variant and each constraint may itself be a TypeVar or a Union.
+    # In these cases, every inner type of the variant has to be extracted
+    # and verified to be a subtype of some constraint. Likewise, the inner
+    # types of any constraint that is a TypeVar or a Union have to be
+    # extracted, and the variant is accepted if it is a subtype of any of
+    # them.
+
+    # Variant: its member types when it is a TypeVar or Union, otherwise None
+    vs = _decompose_type(variant, to_list=False)
+
+    # Variant is TypeVar or Union: every member type must satisfy the constraints
+    if vs is not None:
+        return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs)
+
+    # Variant is not TypeVar or Union
+    if hasattr(variant, "__origin__") and variant.__origin__ is not None:
+        v_origin = variant.__origin__
+        # In Python-3.9 typing library untyped generics do not have args
+        v_args = getattr(variant, "__args__", None)
+    else:
+        v_origin = variant
+        v_args = None
+
+    # Constraints: the variant only needs to match one of them
+    for constraint in constraints:
+        cs = _decompose_type(constraint, to_list=False)
+
+        # Constraint is TypeVar or Union
+        if cs is not None:
+            if _issubtype_with_constraints(variant, cs, recursive):
+                return True
+        # Constraint is not TypeVar or Union
+        else:
+            # __origin__ can be None for plain list, tuple, ... in Python 3.6
+            if hasattr(constraint, "__origin__") and constraint.__origin__ is not None:
+                c_origin = constraint.__origin__
+                if v_origin == c_origin:
+                    if not recursive:
+                        return True  # origins match; type arguments are not compared
+                    # In Python-3.9 typing library untyped generics do not have args
+                    c_args = getattr(constraint, "__args__", None)
+                    if c_args is None or len(c_args) == 0:
+                        return True
+                    if (
+                        v_args is not None
+                        and len(v_args) == len(c_args)
+                        and all(
+                            issubtype(v_arg, c_arg)
+                            for v_arg, c_arg in zip(v_args, c_args, strict=True)
+                        )
+                    ):
+                        return True
+            # Constraint has no origin (e.g. Tuple[int] -> Tuple): compare it to the variant's origin
+            else:
+                if v_origin == constraint:
+                    return True
+
+    return False
+
+
+def issubinstance(data, data_type):
+    if not issubtype(type(data), data_type, recursive=False):  # outer type only
+        return False
+
+    # Untyped generics may lack the __args__ attribute (seen in the Python-3.9 typing library)
+    dt_args = getattr(data_type, "__args__", None)
+    if isinstance(data, tuple):
+        if dt_args is None or len(dt_args) == 0:
+            return True
+        if len(dt_args) != len(data):  # NOTE(review): no special case for Tuple[X, ...]
+            return False
+        return all(issubinstance(d, t) for d, t in zip(data, dt_args, strict=True))
+    elif isinstance(data, (list, set)):
+        if dt_args is None or len(dt_args) == 0:
+            return True
+        t = dt_args[0]  # every element is checked against the first type argument
+        return all(issubinstance(d, t) for d in data)
+    elif isinstance(data, dict):
+        if dt_args is None or len(dt_args) == 0:
+            return True
+        kt, vt = dt_args  # key type, value type
+        return all(
+            issubinstance(k, kt) and issubinstance(v, vt) for k, v in data.items()
+        )
+
+    return True  # any other data: the outer type check above is the whole test
+
+
+# [Note: TypeMeta and TypeAlias]
+# To stay compatible with Python 3.6, DataPipe typing is built on a metaclass.
+# TODO: Once PyTorch drops support for Python 3.6, this can be converted to
+# the Alias system, using `__class_getitem__` for DataPipe. The typing
+# system would then gain performance and avoid metaclass conflicts, as
+# elaborated in https://www.python.org/dev/peps/pep-0560/
+
+
+class _DataPipeType:
+    r"""Save type annotation in `param`."""
+
+    def __init__(self, param) -> None:
+        self.param = param
+
+    def __repr__(self) -> str:
+        return _type_repr(self.param)
+
+    def __eq__(self, other):
+        if isinstance(other, _DataPipeType):
+            return self.param == other.param
+        return NotImplemented
+
+    def __hash__(self):
+        return hash(self.param)
+
+    def issubtype(self, other):
+        if isinstance(other.param, _GenericAlias):
+            if getattr(other.param, "__origin__", None) is Generic:
+                return True
+        if isinstance(other, _DataPipeType):
+            return issubtype(self.param, other.param)
+        if isinstance(other, type):
+            return issubtype(self.param, other)
+        raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")
+
+    def issubtype_of_instance(self, other):
+        return issubinstance(other, self.param)
+
+
+# Default type for a DataPipe without annotation: an unparameterized Generic[_T_co]
+_T_co = TypeVar("_T_co", covariant=True)
+# pyrefly: ignore [invalid-annotation]
+_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
+
+
+class _DataPipeMeta(GenericMeta):
+    r"""
+    Metaclass for `DataPipe`.
+
+    Written to add a `type` attribute and `__init_subclass__`, but `__new__` currently returns before that code runs.
+
+    Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`.
+    """
+
+    type: _DataPipeType
+
+    def __new__(cls, name, bases, namespace, **kwargs):
+        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+        # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
+        # pyrefly: ignore [no-access]
+        cls.__origin__ = None
+        if "type" in namespace:
+            return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+        namespace["__type_class__"] = False
+        #  For plain derived class without annotation
+        for base in bases:
+            if isinstance(base, _DataPipeMeta):
+                return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+        namespace.update(
+            {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass}
+        )
+        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+    def __init__(self, name, bases, namespace, **kwargs) -> None:
+        super().__init__(name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+    # TODO: Fix isinstance bug (single-underscore name: this is not the `__getitem__` hook)
+    @_tp_cache
+    def _getitem_(self, params):
+        if params is None:
+            raise TypeError(f"{self.__name__}[t]: t can not be None")
+        if isinstance(params, str):
+            params = ForwardRef(params)  # string annotations are resolved later
+        if not isinstance(params, tuple):
+            params = (params,)
+
+        msg = f"{self.__name__}[t]: t must be a type"
+        params = tuple(_type_check(p, msg) for p in params)
+
+        if isinstance(self.type.param, _GenericAlias):
+            orig = getattr(self.type.param, "__origin__", None)
+            if isinstance(orig, type) and orig is not Generic:
+                p = self.type.param[params]  # type: ignore[index]
+                t = _DataPipeType(p)
+                l = len(str(self.type)) + 2  # presumably len of the "[<type>]" name suffix
+                name = self.__name__[:-l]
+                name = name + "[" + str(t) + "]"
+                bases = (self,) + self.__bases__
+                return self.__class__(
+                    name,
+                    bases,
+                    {
+                        "__init_subclass__": _dp_init_subclass,
+                        "type": t,
+                        "__type_class__": True,
+                    },
+                )
+
+        if len(params) > 1:
+            raise TypeError(
+                f"Too many parameters for {self} actual {len(params)}, expected 1"
+            )
+
+        t = _DataPipeType(params[0])
+
+        if not t.issubtype(self.type):
+            raise TypeError(
+                f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]"
+            )
+
+        # Types are equal, fast path for inheritance: reuse this class unchanged
+        if self.type == t:
+            return self
+
+        name = self.__name__ + "[" + str(t) + "]"
+        bases = (self,) + self.__bases__
+
+        return self.__class__(
+            name,
+            bases,
+            {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t},
+        )
+
+    # TODO: Fix isinstance bug (single-underscore name: this is not the `__eq__` hook)
+    def _eq_(self, other):
+        if not isinstance(other, _DataPipeMeta):
+            return NotImplemented
+        if self.__origin__ is None or other.__origin__ is None:  # type: ignore[has-type]
+            return self is other
+        return (
+            self.__origin__ == other.__origin__  # type: ignore[has-type]
+            and self.type == other.type
+        )
+
+    # TODO: Fix isinstance bug (single-underscore name: this is not the `__hash__` hook)
+    def _hash_(self):
+        return hash((self.__name__, self.type))
+
+
+class _IterDataPipeMeta(_DataPipeMeta):
+    r"""
+    Metaclass for `IterDataPipe`; inherits from `_DataPipeMeta`.
+
+    Wrap a class's own `reset` in a snapshot-state guard and pass a namespace defining `__iter__` to `hook_iterator`.
+    """
+
+    def __new__(cls, name, bases, namespace, **kwargs):
+        if "reset" in namespace:
+            reset_func = namespace["reset"]
+
+            @functools.wraps(reset_func)
+            def conditional_reset(*args, **kwargs) -> None:
+                r"""
+                Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`.
+
+                This allows a recently restored DataPipe to preserve its restored state during the initial `__iter__` call.
+                """
+                datapipe = args[0]  # `reset` is a method, so args[0] is the DataPipe instance
+                if datapipe._snapshot_state in (
+                    _SnapshotState.Iterating,
+                    _SnapshotState.NotStarted,
+                ):
+                    # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have
+                    # already begun iterating.
+                    datapipe._number_of_samples_yielded = 0
+                    datapipe._fast_forward_iterator = None
+                    reset_func(*args, **kwargs)
+                datapipe._snapshot_state = _SnapshotState.Iterating  # set whether or not reset ran
+
+            namespace["reset"] = conditional_reset
+
+        if "__iter__" in namespace:
+            hook_iterator(namespace)  # return value unused; presumably edits `namespace` in place
+        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+
+def _dp_init_subclass(sub_cls, *args, **kwargs) -> None:
+    # Attach `reinforce_type` so that instances of the subclass can narrow their type
+    sub_cls.reinforce_type = reinforce_type
+
+    # TODO:
+    # - add global switch for type checking at compile-time
+
+    # Skip classes flagged with `__type_class__` (internal type classes)
+    if getattr(sub_cls, "__type_class__", False):
+        return
+
+    # A string type is held as a ForwardRef: resolve it against the subclass's module globals
+    if isinstance(sub_cls.type.param, ForwardRef):
+        base_globals = sys.modules[sub_cls.__module__].__dict__
+        try:
+            param = _eval_type(sub_cls.type.param, base_globals, locals())
+            sub_cls.type.param = param
+        except TypeError as e:
+            raise TypeError(
+                f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing"
+            ) from e
+
+    if "__iter__" in sub_cls.__dict__:
+        iter_fn = sub_cls.__dict__["__iter__"]
+        hints = get_type_hints(iter_fn)
+        if "return" in hints:
+            return_hint = hints["return"]
+            # A bare `Iterator` return hint carries no item type to validate
+            if return_hint == Iterator:
+                return
+            if not (
+                hasattr(return_hint, "__origin__")
+                and (
+                    return_hint.__origin__ == Iterator
+                    or return_hint.__origin__ == collections.abc.Iterator
+                )
+            ):
+                raise TypeError(
+                    "Expected 'Iterator' as the return annotation for `__iter__` of {}"
+                    ", but found {}".format(
+                        sub_cls.__name__, _type_repr(hints["return"])
+                    )
+                )
+            data_type = return_hint.__args__[0]  # the T of Iterator[T]
+            if not issubtype(data_type, sub_cls.type.param):
+                raise TypeError(
+                    f"Expected return type of '__iter__' as a subtype of {sub_cls.type},"
+                    f" but found {_type_repr(data_type)} for {sub_cls.__name__}"
+                )
+
+
+def reinforce_type(self, expected_type):
+    r"""
+    Reinforce the type for DataPipe instance.
+
+    And the 'expected_type' is required to be a subtype of the original type
+    hint to restrict the type requirement of DataPipe instance.
+    """
+    if isinstance(expected_type, tuple):
+        expected_type = tuple[expected_type]  # type: ignore[valid-type]
+    _type_check(expected_type, msg="'expected_type' must be a type")
+
+    if not issubtype(expected_type, self.type.param):
+        raise TypeError(
+            f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}"
+        )
+
+    self.type = _DataPipeType(expected_type)
+    return self
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97d8ce1695a5ded385753ee975ac9d39bc6709c7
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4649dc2019703e062c217ae3aa5e09ae693152c4
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1efc492aaa968b4c2715810aaaa8cc627989297d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dc0aff4da5ee7b48a581c265a28ebd6b09b3344
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..069fd909fea8229012f87dccfc6770f6f56166b0
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..7f49cc212383b2a635c36e1dc96c040d1d63868d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi
@@ -0,0 +1,746 @@
+# @generated by torch/utils/data/datapipes/gen_pyi.py from datapipe.pyi.in
+# mypy: allow-untyped-defs
+# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection
+# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt
+# Note that, for mypy, a .pyi file takes precedence over the .py file, so we must define the interface for other
+# classes/objects here, even though we are not injecting extra code into them at the moment.
+
+from collections.abc import Callable, Iterable, Iterator
+from typing import Any, Literal, TypeVar
+
+from torch.utils.data import Dataset, default_collate, IterableDataset
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState
+from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
+
+_T = TypeVar("_T")  # element type of DataChunk
+_T_co = TypeVar("_T_co", covariant=True)  # element type of MapDataPipe / IterDataPipe
+UNTRACABLE_DATAFRAME_PIPES: Any
+
+class DataChunk(list[_T]):  # stub: a list subclass parameterized by its element type
+    items: list[_T]
+    def __init__(self, items: Iterable[_T]) -> None: ...
+    def as_str(self, indent: str = "") -> str: ...
+    def __iter__(self) -> Iterator[_T]: ...
+    def raw_iterator(self) -> Iterator[_T]: ...
+
+class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
+    functions: dict[str, Callable] = ...
+    reduce_ex_hook: Callable | None = ...
+    getstate_hook: Callable | None = ...
+    str_hook: Callable | None = ...
+    repr_hook: Callable | None = ...
+    def __getattr__(self, attribute_name: Any): ...
+    @classmethod
+    def register_function(cls, function_name: Any, function: Any) -> None: ...
+    @classmethod
+    def register_datapipe_as_function(
+        cls,
+        function_name: Any,
+        cls_to_register: Any,
+    ): ...
+    def __getstate__(self): ...
+    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
+    @classmethod
+    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
+    @classmethod
+    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
+    # Functional form of 'BatcherMapDataPipe'
+    def batch(
+        self,
+        batch_size: int,
+        drop_last: bool = False,
+        wrapper_class: type[DataChunk] = DataChunk,
+    ) -> MapDataPipe:
+        r"""
+        Create mini-batches of data (functional name: ``batch``).
+
+        An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
+        or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.
+
+        Args:
+            datapipe: Iterable DataPipe being batched
+            batch_size: The size of each batch
+            drop_last: Option to drop the last batch if it's not full
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp = SequenceWrapper(range(10))
+            >>> batch_dp = dp.batch(batch_size=2)
+            >>> list(batch_dp)
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+        """
+    # Functional form of 'ConcaterMapDataPipe'
+    def concat(self, *datapipes: MapDataPipe) -> MapDataPipe:
+        r"""
+        Concatenate multiple Map DataPipes (functional name: ``concat``).
+
+        The new index is the cumulative sum of the lengths of the source DataPipes.
+        For example, if there are 2 source DataPipes both with length 5,
+        index 0 to 4 of the resulting `ConcaterMapDataPipe` would refer to
+        elements of the first DataPipe, and 5 to 9 would refer to elements
+        of the second DataPipe.
+
+        Args:
+            datapipes: Map DataPipes being concatenated
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp1 = SequenceWrapper(range(3))
+            >>> dp2 = SequenceWrapper(range(3))
+            >>> concat_dp = dp1.concat(dp2)
+            >>> list(concat_dp)
+            [0, 1, 2, 0, 1, 2]
+        """
+    # Functional form of 'MapperMapDataPipe'
+    def map(self, fn: Callable = ...) -> MapDataPipe:
+        r"""
+        Apply the input function over each item from the source DataPipe (functional name: ``map``).
+
+        The function can be any regular Python function or partial object. Lambda
+        function is not recommended as it is not supported by pickle.
+
+        Args:
+            datapipe: Source MapDataPipe
+            fn: Function being applied to each item
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
+            >>> def add_one(x):
+            ...     return x + 1
+            >>> dp = SequenceWrapper(range(10))
+            >>> map_dp_1 = dp.map(add_one)
+            >>> list(map_dp_1)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+            >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
+            >>> list(map_dp_2)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        """
+    # Functional form of 'ShufflerIterDataPipe'
+    def shuffle(self, *, indices: list | None = None) -> IterDataPipe:
+        r"""
+        Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).
+
+        When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
+        set up random seed are different based on :attr:`num_workers`.
+
+        For single-process mode (:attr:`num_workers == 0`), the random seed is set before
+        the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
+        mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
+        for each worker process.
+
+        Args:
+            datapipe: MapDataPipe being shuffled
+            indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp = SequenceWrapper(range(10))
+            >>> shuffle_dp = dp.shuffle().set_seed(0)
+            >>> list(shuffle_dp)
+            [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
+            >>> list(shuffle_dp)
+            [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
+            >>> # Reset seed for Shuffler
+            >>> shuffle_dp = shuffle_dp.set_seed(0)
+            >>> list(shuffle_dp)
+            [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
+
+        Note:
+            Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it returns an
+            ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be insensitive to
+            the order of data for the sake of random reads, but ``IterDataPipe`` depends on the order
+            of data during data-processing.
+        """
+    # Functional form of 'ZipperMapDataPipe'
+    def zip(self, *datapipes: MapDataPipe[_T_co]) -> MapDataPipe:
+        r"""
+        Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
+
+        This MapDataPipe is out of bound as soon as the shortest input DataPipe is exhausted.
+
+        Args:
+            *datapipes: Map DataPipes being aggregated
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp1 = SequenceWrapper(range(3))
+            >>> dp2 = SequenceWrapper(range(10, 13))
+            >>> zip_dp = dp1.zip(dp2)
+            >>> list(zip_dp)
+            [(0, 10), (1, 11), (2, 12)]
+        """
+
+class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
+    functions: dict[str, Callable] = ...
+    reduce_ex_hook: Callable | None = ...
+    getstate_hook: Callable | None = ...
+    str_hook: Callable | None = ...
+    repr_hook: Callable | None = ...
+    _number_of_samples_yielded: int = ...
+    _snapshot_state: _SnapshotState = _SnapshotState.Iterating  # noqa: PYI015
+    _fast_forward_iterator: Iterator | None = ...
+    def __getattr__(self, attribute_name: Any): ...
+    @classmethod
+    def register_function(cls, function_name: Any, function: Any) -> None: ...
+    @classmethod
+    def register_datapipe_as_function(
+        cls,
+        function_name: Any,
+        cls_to_register: Any,
+        enable_df_api_tracing: bool = ...,
+    ): ...
+    def __getstate__(self): ...
+    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
+    @classmethod
+    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
+    @classmethod
+    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
+    # Functional form of 'BatcherIterDataPipe'
+    def batch(
+        self,
+        batch_size: int,
+        drop_last: bool = False,
+        wrapper_class: type[DataChunk] = DataChunk,
+    ) -> IterDataPipe:
+        r"""
+        Creates mini-batches of data (functional name: ``batch``).
+
+        An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the
+        last batch if ``drop_last`` is set to ``False``.
+
+        Args:
+            datapipe: Iterable DataPipe being batched
+            batch_size: The size of each batch
+            drop_last: Option to drop the last batch if it's not full
+            wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
+                defaults to ``DataChunk``
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp = IterableWrapper(range(10))
+            >>> dp = dp.batch(batch_size=3, drop_last=True)
+            >>> list(dp)
+            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+        """
+    # Functional form of 'CollatorIterDataPipe'
+    def collate(
+        self,
+        conversion: Callable[..., Any]| dict[str | Any, Callable | Any]| None = default_collate,
+        collate_fn: Callable | None = None,
+    ) -> IterDataPipe:  # fmt: skip
+        r"""
+        Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).
+
+        By default, it uses :func:`torch.utils.data.default_collate`.
+
+        .. note::
+            While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
+            default behavior and `functools.partial` to specify any additional arguments.
+
+        Args:
+            datapipe: Iterable DataPipe being collated
+            collate_fn: Customized collate function to collect and combine data or a batch of data.
+                Default function collates to Tensor(s) based on data type.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> # Convert integer data to float Tensor
+            >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
+            ...     def __init__(self, start, end):
+            ...         super(MyIterDataPipe).__init__()
+            ...         assert end > start, "this example only works with end >= start"
+            ...         self.start = start
+            ...         self.end = end
+            ...
+            ...     def __iter__(self):
+            ...         return iter(range(self.start, self.end))
+            ...
+            ...     def __len__(self):
+            ...         return self.end - self.start
+            >>> ds = MyIterDataPipe(start=3, end=7)
+            >>> print(list(ds))
+            [3, 4, 5, 6]
+            >>> def collate_fn(batch):
+            ...     return torch.tensor(batch, dtype=torch.float)
+            >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
+            >>> print(list(collated_ds))
+            [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
+        """
+    # Functional form of 'ConcaterIterDataPipe'
+    def concat(self, *datapipes: IterDataPipe) -> IterDataPipe:
+        r"""
+        Concatenates multiple Iterable DataPipes (functional name: ``concat``).
+
+        The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones.
+
+        Args:
+            datapipes: Iterable DataPipes being concatenated
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> import random
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp1 = IterableWrapper(range(3))
+            >>> dp2 = IterableWrapper(range(5))
+            >>> list(dp1.concat(dp2))
+            [0, 1, 2, 0, 1, 2, 3, 4]
+        """
+    # Functional form of 'DemultiplexerIterDataPipe'
+    def demux(
+        self,
+        num_instances: int,
+        classifier_fn: Callable[[_T_co], int | None],
+        drop_none: bool = False,
+        buffer_size: int = 1000,
+    ) -> list[IterDataPipe]:
+        r"""
+        Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).
+
+        A list of the child DataPipes is returned from this operation.
+
+        Args:
+            datapipe: Iterable DataPipe being filtered
+            num_instances: number of instances of the DataPipe to create
+            classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
+            drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
+            buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
+                DataPipes while waiting for their values to be yielded.
+                Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+
+        Examples:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> def odd_or_even(n):
+            ...     return n % 2
+            >>> source_dp = IterableWrapper(range(5))
+            >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
+            >>> list(dp1)
+            [0, 2, 4]
+            >>> list(dp2)
+            [1, 3]
+            >>> # It can also filter out any element that gets `None` from the `classifier_fn`
+            >>> def odd_or_even_no_zero(n):
+            ...     return n % 2 if n != 0 else None
+            >>> dp1, dp2 = source_dp.demux(
+            ...     num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True
+            ... )
+            >>> list(dp1)
+            [2, 4]
+            >>> list(dp2)
+            [1, 3]
+        """
+    # Functional form of 'FilterIterDataPipe'
+    def filter(self, filter_fn: Callable, input_col=None) -> IterDataPipe:
+        r"""
+        Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).
+
+        Args:
+            datapipe: Iterable DataPipe being filtered
+            filter_fn: Customized function mapping an element to a boolean.
+            input_col: Index or indices of data to which ``filter_fn`` is applied, such as:
+
+                - ``None`` as default to apply ``filter_fn`` to the data directly.
+                - Integer(s) is used for list/tuple.
+                - Key(s) is used for dict.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> def is_even(n):
+            ...     return n % 2 == 0
+            >>> dp = IterableWrapper(range(5))
+            >>> filter_dp = dp.filter(filter_fn=is_even)
+            >>> list(filter_dp)
+            [0, 2, 4]
+        """
+    # Functional form of 'ForkerIterDataPipe'
+    def fork(
+        self,
+        num_instances: int,
+        buffer_size: int = 1000,
+        copy: Literal["shallow", "deep"] | None = None,
+    ) -> list[IterDataPipe]:
+        r"""
+        Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).
+
+        Args:
+            datapipe: Iterable DataPipe being copied
+            num_instances: number of instances of the datapipe to create
+            buffer_size: this restricts how far ahead the leading child DataPipe
+               can read relative to the slowest child DataPipe.
+               Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+            copy: copy strategy to use for items yielded by each branch. Supported
+                options are ``None`` for no copying, ``"shallow"`` for shallow object
+                copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
+
+        Note:
+            All branches of the forked pipeline return the identical object unless
+            the copy parameter is supplied. If the object is mutable or contains
+            mutable objects, changing them in one branch will affect all others.
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> source_dp = IterableWrapper(range(5))
+            >>> dp1, dp2 = source_dp.fork(num_instances=2)
+            >>> list(dp1)
+            [0, 1, 2, 3, 4]
+            >>> list(dp2)
+            [0, 1, 2, 3, 4]
+        """
+    # Functional form of 'GrouperIterDataPipe'
+    def groupby(
+        self,
+        group_key_fn: Callable[[_T_co], Any],
+        *,
+        keep_key: bool = False,
+        buffer_size: int = 10000,
+        group_size: int | None = None,
+        guaranteed_group_size: int | None = None,
+        drop_remaining: bool = False,
+    ) -> IterDataPipe:
+        r"""
+        Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``.
+
+        (functional name: ``groupby``).
+
+        The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
+        will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
+        the DataPipe will yield the largest batch with the same key, provided that its size is larger
+        than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.
+
+        After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
+        will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.
+
+        Args:
+            datapipe: Iterable datapipe to be grouped
+            group_key_fn: Function used to generate group key from the data of the source datapipe
+            keep_key: Option to yield the matching key along with the items in a tuple,
+                resulting in `(key, [items])`; otherwise only `[items]` is yielded
+            buffer_size: The size of buffer for ungrouped data
+            group_size: The max size of each group, a batch is yielded as soon as it reaches this size
+            guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
+            drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer
+                when the buffer is full
+
+        Example:
+            >>> import os
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> def group_fn(file):
+            ...     return os.path.basename(file).split(".")[0]
+            >>> source_dp = IterableWrapper(
+            ...     ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
+            ... )
+            >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
+            >>> list(dp0)
+            [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
+            >>> # A group is yielded as soon as its size equals `group_size`
+            >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
+            >>> list(dp1)
+            [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
+            >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
+            >>> dp2 = source_dp.groupby(
+            ...     group_key_fn=group_fn,
+            ...     buffer_size=3,
+            ...     group_size=3,
+            ...     guaranteed_group_size=2,
+            ... )
+            >>> list(dp2)
+            [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
+        """
+    # Functional form of 'FileListerIterDataPipe'
+    def list_files(
+        self,
+        masks: str | list[str] = "",
+        *,
+        recursive: bool = False,
+        abspath: bool = False,
+        non_deterministic: bool = False,
+        length: int = -1,
+    ) -> IterDataPipe:
+        r"""
+        Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.
+
+        Multiple root directories can be provided (functional name: ``list_files``).
+
+        Args:
+            root: Root directory or a sequence of root directories
+            masks: Unix style filter string or string list for filtering file name(s)
+            recursive: Whether to return pathname from nested directories or not
+            abspath: Whether to return relative pathname or absolute pathname
+            non_deterministic: Whether to return pathname in sorted order or not.
+                If ``False``, the results yielded from each root directory will be sorted
+            length: Nominal length of the datapipe
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import FileLister
+            >>> dp = FileLister(root=".", recursive=True)
+            >>> list(dp)
+            ['example.py', './data/data.tar']
+        """
+    # Functional form of 'MapperIterDataPipe'
+    def map(
+        self,
+        fn: Callable,
+        input_col=None,
+        output_col=None,
+    ) -> IterDataPipe:
+        r"""
+        Applies a function over each item from the source DataPipe (functional name: ``map``).
+
+        The function can be any regular Python function or partial object. Lambda
+        function is not recommended as it is not supported by pickle.
+
+        Args:
+            datapipe: Source Iterable DataPipe
+            fn: Function being applied over each item
+            input_col: Index or indices of data to which ``fn`` is applied, such as:
+
+                - ``None`` as default to apply ``fn`` to the data directly.
+                - Integer(s) is used for list/tuple.
+                - Key(s) is used for dict.
+
+            output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
+                only when ``input_col`` is not ``None``
+
+                - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
+                  multiple indices, the left-most one is used, and other indices will be removed.
+                - Integer is used for list/tuple. ``-1`` appends the result at the end.
+                - Key is used for dict. New key is acceptable.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
+            >>> def add_one(x):
+            ...     return x + 1
+            >>> dp = IterableWrapper(range(10))
+            >>> # Invocation via functional form is preferred
+            ... map_dp_1 = dp.map(add_one)
+            >>> list(map_dp_1)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+            >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
+            >>> # Use `functools.partial` or explicitly define the function instead
+            >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
+            >>> list(map_dp_2)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        """
+    # Functional form of 'MultiplexerIterDataPipe'
+    def mux(self, *datapipes) -> IterDataPipe:
+        r"""
+        Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).
+
+        As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
+        and so on. It ends when the shortest input DataPipe is exhausted.
+
+        Args:
+            datapipes: Iterable DataPipes that will take turns yielding their elements until the shortest DataPipe is exhausted
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp1, dp2, dp3 = (
+            ...     IterableWrapper(range(3)),
+            ...     IterableWrapper(range(10, 15)),
+            ...     IterableWrapper(range(20, 25)),
+            ... )
+            >>> list(dp1.mux(dp2, dp3))
+            [0, 10, 20, 1, 11, 21, 2, 12, 22]
+        """
+    # Functional form of 'FileOpenerIterDataPipe'
+    def open_files(
+        self,
+        mode: str = "r",
+        encoding: str | None = None,
+        length: int = -1,
+    ) -> IterDataPipe:
+        r"""
+        Given pathnames, opens files and yields pathname and file stream in a tuple (functional name: ``open_files``).
+
+        Args:
+            datapipe: Iterable datapipe that provides pathnames
+            mode: An optional string that specifies the mode in which
+                the file is opened by ``open()``. It defaults to ``r``; other options are
+                ``b`` for reading in binary mode and ``t`` for text mode.
+            encoding: An optional string that specifies the encoding of the
+                underlying file. It defaults to ``None`` to match the default encoding of ``open``.
+            length: Nominal length of the datapipe
+
+        Note:
+            The opened file handles will be closed by Python's GC periodically. Users can choose
+            to close them explicitly.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import (
+            ...     FileLister,
+            ...     FileOpener,
+            ...     StreamReader,
+            ... )
+            >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt"))
+            >>> dp = FileOpener(dp)
+            >>> dp = StreamReader(dp)
+            >>> list(dp)
+            [('./abc.txt', 'abc')]
+        """
+    # Functional form of 'StreamReaderIterDataPipe'
+    def read_from_stream(self, chunk: int | None = None) -> IterDataPipe:
+        r"""
+        Given IO streams and their label names, yields bytes with label name as a tuple.
+
+        (functional name: ``read_from_stream``).
+
+        Args:
+            datapipe: Iterable DataPipe that provides label/URL and byte stream
+            chunk: Number of bytes to be read from stream per iteration.
+                If ``None``, all bytes will be read until the EOF.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
+            >>> from io import StringIO
+            >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
+            >>> list(StreamReader(dp, chunk=1))
+            [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
+        """
+    # Functional form of 'RoutedDecoderIterDataPipe'
+    def routed_decode(
+        self,
+        *handlers: Callable,
+        key_fn: Callable = ...,
+    ) -> IterDataPipe:
+        r"""
+        Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple.
+
+        (functional name: ``routed_decode``).
+
+        Args:
+            datapipe: Iterable datapipe that provides pathname and binary stream in tuples
+            handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
+                handlers will be set as default. If multiple handlers are provided, the priority
+                order follows the order of handlers (the first handler has the top priority)
+            key_fn: Function for decoder to extract key from pathname to dispatch handlers.
+                Default is set to extract file extension from pathname
+
+        Note:
+            When ``key_fn`` is specified returning anything other than extension, the default
+            handler will not work and users need to specify custom handler. Custom handler
+            could use regex to determine the eligibility to handle data.
+        """
+    # Functional form of 'ShardingFilterIterDataPipe'
+    def sharding_filter(self, sharding_group_filter=None) -> IterDataPipe:
+        r"""
+        Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).
+
+        After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
+        original DataPipe, where `n` equals the number of instances.
+
+        Args:
+            source_datapipe: Iterable DataPipe that will be sharded
+        """
+    # Functional form of 'ShufflerIterDataPipe'
+    def shuffle(
+        self,
+        *,
+        buffer_size: int = 10000,
+        unbatch_level: int = 0,
+    ) -> IterDataPipe:
+        r"""
+        Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).
+
+        The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then,
+        each item will be yielded from the buffer by reservoir sampling via iterator.
+
+        ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
+        datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
+        ``buffer_size`` is required to be greater than or equal to the size of datapipe.
+
+        When it is used with :class:`torch.utils.data.DataLoader`, the methods to
+        set up random seed are different based on :attr:`num_workers`.
+
+        For single-process mode (:attr:`num_workers == 0`), the random seed is set before
+        the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
+        mode (:attr:`num_workers > 0`), `worker_init_fn` is used to set up a random seed
+        for each worker process.
+
+        Args:
+            datapipe: The IterDataPipe being shuffled
+            buffer_size: The buffer size for shuffling (default to ``10000``)
+            unbatch_level: Specifies if it is necessary to unbatch source data before
+                applying the shuffle
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp = IterableWrapper(range(10))
+            >>> shuffle_dp = dp.shuffle()
+            >>> list(shuffle_dp)
+            [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
+        """
+    # Functional form of 'UnBatcherIterDataPipe'
+    def unbatch(self, unbatch_level: int = 1) -> IterDataPipe:
+        r"""
+        Undoes batching of data (functional name: ``unbatch``).
+
+        In other words, it flattens the data up to the specified level within a batched DataPipe.
+
+        Args:
+            datapipe: Iterable DataPipe being un-batched
+            unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
+                it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
+            >>> dp1 = source_dp.unbatch()
+            >>> list(dp1)
+            [[0, 1], [2], [3, 4], [5], [6]]
+            >>> dp2 = source_dp.unbatch(unbatch_level=2)
+            >>> list(dp2)
+            [0, 1, 2, 3, 4, 5, 6]
+        """
+    # Functional form of 'ZipperIterDataPipe'
+    def zip(self, *datapipes: IterDataPipe) -> IterDataPipe:
+        r"""
+        Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
+
+        The output is stopped as soon as the shortest input DataPipe is exhausted.
+
+        Args:
+            *datapipes: Iterable DataPipes being aggregated
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp1, dp2, dp3 = (
+            ...     IterableWrapper(range(5)),
+            ...     IterableWrapper(range(10, 15)),
+            ...     IterableWrapper(range(20, 25)),
+            ... )
+            >>> list(dp1.zip(dp2, dp3))
+            [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
+        """
+
+class DFIterDataPipe(IterDataPipe):
+    def _is_dfpipe(self): ...
+    def __iter__(self): ...
+
+class _DataPipeSerializationWrapper:
+    def __init__(self, datapipe): ...
+    def __getstate__(self): ...
+    def __setstate__(self, state): ...
+    def __len__(self): ...
+
+class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
+    def __iter__(self): ...
+
+class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
+    def __getitem__(self, idx): ...
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee81b7d2862abf7c959334ff6a97691143ed69c8
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/constants.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/constants.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5dfec19e61e30b3b83eef3bd501de5cc9dce0567
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/constants.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1a3546cdda9c9f7cb0765a7795831ed43328bb9
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/version.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7824c03e797583f195e71359aca947dfaf1f67fb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/hipify/__pycache__/version.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/jit/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/jit/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78e1f2975917bb37343c96175c9db1e528eab1dc
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/jit/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3192aae6e2232184f67a6562fbe412269b7a4a37
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d7c59bbb12a8522951bb335e45710ef163146ce
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9d0e9a486b10e3d0bab6ab303fff28dab016af1
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29820be1f95050e11e5d088ba3b37c3fa4510d09
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6692898bbfc0331790ddc4a7ae37f6de2b68b40
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70ff583ba8b652cafa6f53abfcf64cc61e450cbf
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7ec95d4cd8c8c92d0537bfa7fdeccc2a145be5e
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80d176068caccff9249f46d9182776b360b83f63
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf7c97f74e4e98b53447831a2563307bd26ad895
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce5f29d32a3f870c195532fc6c9819a17d4237e6
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16bf96d9255218ef541d313e292e3a6d701ce43d
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88c949643f25f7975e13d1c3919800957397badb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14133092c5ba5e90739a017fd47450197263bacd
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89e5e2010bae775738b36674e41fefccfe33cb4f
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/viz/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/viz/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e58e1d0f5abe1f570539e513de4b44ecfc3c351c
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/viz/__pycache__/__init__.cpython-312.pyc differ
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e63327f1b7e29fa435e3160218e2fa72036523bb
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-312.pyc differ